sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/genmixdb.py
CHANGED
@@ -76,8 +76,8 @@ generation functions. By default, these are included with the feature data in a
|
|
76
76
|
truth generation is turned on with default settings (see truth section) and a single class, i.e., detecting a single
|
77
77
|
type of sound. The truth format is a single float per class representing the probability of activity/presence, and
|
78
78
|
multi-class truth is possible by specifying the number of classes and either a scalar index or a vector of indices in
|
79
|
-
which to put the truth result. For example, 'num_class: 3' and '
|
80
|
-
with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
|
79
|
+
which to put the truth result. For example, 'num_class: 3' and 'class_indices: [ 2 ]' adds a 1x3 vector to the feature
|
80
|
+
data with truth put in index 2 (others would be 0) for data/target.wav being an audio clip from sound type of class 2.
|
81
81
|
|
82
82
|
The mixture is created with potential data augmentation functions in the following way:
|
83
83
|
1. apply noise augmentation rule
|
@@ -112,11 +112,10 @@ targets:
|
|
112
112
|
will find all .wav files in the specified directories and process them as targets.
|
113
113
|
|
114
114
|
"""
|
115
|
+
|
115
116
|
import signal
|
116
|
-
from dataclasses import dataclass
|
117
117
|
|
118
118
|
from sonusai.mixture import Mixture
|
119
|
-
from sonusai.mixture import MixtureDatabase
|
120
119
|
|
121
120
|
|
122
121
|
def signal_handler(_sig, _frame):
|
@@ -124,43 +123,33 @@ def signal_handler(_sig, _frame):
|
|
124
123
|
|
125
124
|
from sonusai import logger
|
126
125
|
|
127
|
-
logger.info(
|
126
|
+
logger.info("Canceled due to keyboard interrupt")
|
128
127
|
sys.exit(1)
|
129
128
|
|
130
129
|
|
131
130
|
signal.signal(signal.SIGINT, signal_handler)
|
132
131
|
|
133
132
|
|
134
|
-
|
135
|
-
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
def genmixdb(location: str,
|
146
|
-
save_mix: bool = False,
|
147
|
-
save_ft: bool = False,
|
148
|
-
save_segsnr: bool = False,
|
149
|
-
logging: bool = True,
|
150
|
-
show_progress: bool = False,
|
151
|
-
test: bool = False,
|
152
|
-
save_json: bool = False) -> MixtureDatabase:
|
133
|
+
def genmixdb(
|
134
|
+
location: str,
|
135
|
+
save_mix: bool = False,
|
136
|
+
save_ft: bool = False,
|
137
|
+
save_segsnr: bool = False,
|
138
|
+
logging: bool = True,
|
139
|
+
show_progress: bool = False,
|
140
|
+
test: bool = False,
|
141
|
+
save_json: bool = False,
|
142
|
+
) -> None:
|
143
|
+
from functools import partial
|
153
144
|
from random import seed
|
154
145
|
|
155
146
|
import yaml
|
156
|
-
from tqdm import tqdm
|
157
147
|
|
158
|
-
from sonusai import SonusAIError
|
159
148
|
from sonusai import logger
|
160
|
-
from sonusai.mixture import AugmentationRule
|
161
|
-
from sonusai.mixture import MixtureDatabase
|
162
149
|
from sonusai.mixture import SAMPLE_BYTES
|
163
150
|
from sonusai.mixture import SAMPLE_RATE
|
151
|
+
from sonusai.mixture import AugmentationRule
|
152
|
+
from sonusai.mixture import MixtureDatabase
|
164
153
|
from sonusai.mixture import balance_targets
|
165
154
|
from sonusai.mixture import generate_mixtures
|
166
155
|
from sonusai.mixture import get_all_snrs_from_config
|
@@ -182,11 +171,13 @@ def genmixdb(location: str,
|
|
182
171
|
from sonusai.mixture import populate_spectral_mask_table
|
183
172
|
from sonusai.mixture import populate_target_file_table
|
184
173
|
from sonusai.mixture import populate_top_table
|
174
|
+
from sonusai.mixture import populate_truth_parameters_table
|
185
175
|
from sonusai.mixture import update_mixid_width
|
186
176
|
from sonusai.utils import dataclass_from_dict
|
187
177
|
from sonusai.utils import human_readable_size
|
188
|
-
from sonusai.utils import
|
178
|
+
from sonusai.utils import par_track
|
189
179
|
from sonusai.utils import seconds_to_hms
|
180
|
+
from sonusai.utils import track
|
190
181
|
|
191
182
|
config = load_config(location)
|
192
183
|
initialize_db(location=location, test=test)
|
@@ -197,113 +188,116 @@ def genmixdb(location: str,
|
|
197
188
|
populate_class_label_table(location, config, test)
|
198
189
|
populate_class_weights_threshold_table(location, config, test)
|
199
190
|
populate_spectral_mask_table(location, config, test)
|
191
|
+
populate_truth_parameters_table(location, config, test)
|
200
192
|
|
201
|
-
seed(config[
|
193
|
+
seed(config["seed"])
|
202
194
|
|
203
195
|
if logging:
|
204
|
-
logger.debug(f
|
205
|
-
logger.debug(
|
196
|
+
logger.debug(f"Seed: {config['seed']}")
|
197
|
+
logger.debug("Configuration:")
|
206
198
|
logger.debug(yaml.dump(config))
|
207
199
|
|
208
200
|
if logging:
|
209
|
-
logger.info(
|
201
|
+
logger.info("Collecting targets")
|
210
202
|
|
211
203
|
target_files = get_target_files(config, show_progress=show_progress)
|
212
204
|
|
213
205
|
if len(target_files) == 0:
|
214
|
-
raise
|
206
|
+
raise RuntimeError("Canceled due to no targets")
|
215
207
|
|
216
208
|
populate_target_file_table(location, target_files, test)
|
217
209
|
|
218
210
|
if logging:
|
219
|
-
logger.debug(
|
211
|
+
logger.debug("List of targets:")
|
220
212
|
logger.debug(yaml.dump([target.name for target in mixdb.target_files], default_flow_style=False))
|
221
|
-
logger.debug(
|
213
|
+
logger.debug("")
|
222
214
|
|
223
215
|
if logging:
|
224
|
-
logger.info(
|
216
|
+
logger.info("Collecting noises")
|
225
217
|
|
226
218
|
noise_files = get_noise_files(config, show_progress=show_progress)
|
227
219
|
|
228
220
|
populate_noise_file_table(location, noise_files, test)
|
229
221
|
|
230
222
|
if logging:
|
231
|
-
logger.debug(
|
223
|
+
logger.debug("List of noises:")
|
232
224
|
logger.debug(yaml.dump([noise.name for noise in mixdb.noise_files], default_flow_style=False))
|
233
|
-
logger.debug(
|
225
|
+
logger.debug("")
|
234
226
|
|
235
227
|
if logging:
|
236
|
-
logger.info(
|
228
|
+
logger.info("Collecting impulse responses")
|
237
229
|
|
238
230
|
impulse_response_files = get_impulse_response_files(config)
|
239
231
|
|
240
232
|
populate_impulse_response_file_table(location, impulse_response_files, test)
|
241
233
|
|
242
234
|
if logging:
|
243
|
-
logger.debug(
|
235
|
+
logger.debug("List of impulse responses:")
|
244
236
|
logger.debug(
|
245
|
-
yaml.dump(
|
246
|
-
|
247
|
-
|
237
|
+
yaml.dump(
|
238
|
+
[entry.file for entry in mixdb.impulse_response_files],
|
239
|
+
default_flow_style=False,
|
240
|
+
)
|
241
|
+
)
|
242
|
+
logger.debug("")
|
248
243
|
|
249
244
|
if logging:
|
250
|
-
logger.info(
|
245
|
+
logger.info("Collecting target augmentations")
|
251
246
|
|
252
|
-
target_augmentations = get_augmentation_rules(
|
253
|
-
|
247
|
+
target_augmentations = get_augmentation_rules(
|
248
|
+
rules=config["target_augmentations"], num_ir=mixdb.num_impulse_response_files
|
249
|
+
)
|
254
250
|
mixups = get_mixups(target_augmentations)
|
255
251
|
|
256
252
|
if logging:
|
257
253
|
for mixup in mixups:
|
258
|
-
logger.debug(f
|
254
|
+
logger.debug(f"Expanded list of target augmentation rules for mixup of {mixup}:")
|
259
255
|
for target_augmentation in get_target_augmentations_for_mixup(target_augmentations, mixup):
|
260
256
|
ta_dict = target_augmentation.to_dict()
|
261
|
-
del ta_dict[
|
262
|
-
logger.debug(f
|
263
|
-
logger.debug(
|
257
|
+
del ta_dict["mixup"]
|
258
|
+
logger.debug(f"- {ta_dict}")
|
259
|
+
logger.debug("")
|
264
260
|
|
265
261
|
if logging:
|
266
|
-
logger.info(
|
262
|
+
logger.info("Collecting noise augmentations")
|
267
263
|
|
268
|
-
noise_augmentations = get_augmentation_rules(
|
269
|
-
|
264
|
+
noise_augmentations = get_augmentation_rules(
|
265
|
+
rules=config["noise_augmentations"], num_ir=mixdb.num_impulse_response_files
|
266
|
+
)
|
270
267
|
|
271
268
|
if logging:
|
272
|
-
logger.debug(
|
269
|
+
logger.debug("Expanded list of noise augmentations:")
|
273
270
|
for noise_augmentation in noise_augmentations:
|
274
271
|
na_dict = noise_augmentation.to_dict()
|
275
|
-
del na_dict[
|
276
|
-
logger.debug(f
|
277
|
-
logger.debug(
|
272
|
+
del na_dict["mixup"]
|
273
|
+
logger.debug(f"- {na_dict}")
|
274
|
+
logger.debug("")
|
278
275
|
|
279
276
|
if logging:
|
280
|
-
logger.debug(f
|
281
|
-
logger.debug(f
|
282
|
-
logger.debug(f
|
283
|
-
logger.debug(
|
277
|
+
logger.debug(f"SNRs: {config['snrs']}\n")
|
278
|
+
logger.debug(f"Random SNRs: {config['random_snrs']}\n")
|
279
|
+
logger.debug(f"Noise mix mode: {mixdb.noise_mix_mode}\n")
|
280
|
+
logger.debug("Spectral masks:")
|
284
281
|
for spectral_mask in mixdb.spectral_masks:
|
285
|
-
logger.debug(f
|
286
|
-
logger.debug(
|
287
|
-
|
288
|
-
if mixdb.truth_mutex and any(mixup > 1 for mixup in mixups):
|
289
|
-
raise SonusAIError(f'Mutex truth mode is not compatible with mixup')
|
282
|
+
logger.debug(f"- {spectral_mask}")
|
283
|
+
logger.debug("")
|
290
284
|
|
291
285
|
if logging:
|
292
|
-
logger.info(
|
286
|
+
logger.info("Collecting augmented targets")
|
293
287
|
|
294
288
|
augmented_targets = get_augmented_targets(target_files, target_augmentations, mixups)
|
295
289
|
|
296
|
-
if config[
|
297
|
-
class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config[
|
290
|
+
if config["class_balancing"]:
|
291
|
+
class_balancing_augmentation = dataclass_from_dict(AugmentationRule, config["class_balancing_augmentation"])
|
298
292
|
augmented_targets, target_augmentations = balance_targets(
|
299
293
|
augmented_targets=augmented_targets,
|
300
294
|
targets=target_files,
|
301
295
|
target_augmentations=target_augmentations,
|
302
296
|
class_balancing_augmentation=class_balancing_augmentation,
|
303
297
|
num_classes=mixdb.num_classes,
|
304
|
-
truth_mutex=mixdb.truth_mutex,
|
305
298
|
num_ir=mixdb.num_impulse_response_files,
|
306
|
-
mixups=mixups
|
299
|
+
mixups=mixups,
|
300
|
+
)
|
307
301
|
|
308
302
|
target_audio_samples = sum([targets.samples for targets in mixdb.target_files])
|
309
303
|
target_audio_duration = target_audio_samples / SAMPLE_RATE
|
@@ -311,13 +305,17 @@ def genmixdb(location: str,
|
|
311
305
|
noise_audio_samples = noise_audio_duration * SAMPLE_RATE
|
312
306
|
|
313
307
|
if logging:
|
314
|
-
logger.info(
|
315
|
-
logger.info(
|
316
|
-
|
317
|
-
|
318
|
-
|
319
|
-
|
320
|
-
|
308
|
+
logger.info("")
|
309
|
+
logger.info(
|
310
|
+
f"Target audio: {mixdb.num_target_files} files, "
|
311
|
+
f"{human_readable_size(target_audio_samples * SAMPLE_BYTES, 1)}, "
|
312
|
+
f"{seconds_to_hms(seconds=target_audio_duration)}"
|
313
|
+
)
|
314
|
+
logger.info(
|
315
|
+
f"Noise audio: {mixdb.num_noise_files} files, "
|
316
|
+
f"{human_readable_size(noise_audio_samples * SAMPLE_BYTES, 1)}, "
|
317
|
+
f"{seconds_to_hms(seconds=noise_audio_duration)}"
|
318
|
+
)
|
321
319
|
|
322
320
|
used_noise_files, used_noise_samples, mixtures = generate_mixtures(
|
323
321
|
noise_mix_mode=mixdb.noise_mix_mode,
|
@@ -330,41 +328,53 @@ def genmixdb(location: str,
|
|
330
328
|
all_snrs=get_all_snrs_from_config(config),
|
331
329
|
mixups=mixups,
|
332
330
|
num_classes=mixdb.num_classes,
|
333
|
-
truth_mutex=mixdb.truth_mutex,
|
334
331
|
feature_step_samples=mixdb.feature_step_samples,
|
335
|
-
num_ir=mixdb.num_impulse_response_files
|
332
|
+
num_ir=mixdb.num_impulse_response_files,
|
333
|
+
)
|
336
334
|
|
337
335
|
num_mixtures = len(mixtures)
|
338
336
|
update_mixid_width(location, num_mixtures, test)
|
339
337
|
|
340
338
|
if logging:
|
341
|
-
logger.info(
|
342
|
-
logger.info(f
|
339
|
+
logger.info("")
|
340
|
+
logger.info(f"Found {num_mixtures:,} mixtures to process")
|
343
341
|
|
344
342
|
total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
|
345
343
|
|
346
344
|
if logging:
|
347
|
-
log_duration_and_sizes(
|
348
|
-
|
349
|
-
|
350
|
-
|
351
|
-
|
352
|
-
|
353
|
-
|
354
|
-
|
355
|
-
|
356
|
-
|
357
|
-
|
358
|
-
|
345
|
+
log_duration_and_sizes(
|
346
|
+
total_duration=total_duration,
|
347
|
+
num_classes=mixdb.num_classes,
|
348
|
+
feature_step_samples=mixdb.feature_step_samples,
|
349
|
+
feature_parameters=mixdb.feature_parameters,
|
350
|
+
stride=mixdb.fg_stride,
|
351
|
+
desc="Estimated",
|
352
|
+
)
|
353
|
+
logger.info(
|
354
|
+
f"Feature shape: "
|
355
|
+
f"{mixdb.fg_stride} x {mixdb.feature_parameters} "
|
356
|
+
f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
|
357
|
+
)
|
358
|
+
logger.info(f"Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
|
359
|
+
logger.info(f"Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)")
|
360
|
+
logger.info("")
|
359
361
|
|
360
362
|
# Fill in the details
|
361
363
|
if logging:
|
362
|
-
logger.info(
|
363
|
-
progress =
|
364
|
-
mixtures =
|
365
|
-
|
366
|
-
|
367
|
-
|
364
|
+
logger.info("Generating mixtures")
|
365
|
+
progress = track(total=num_mixtures, disable=not show_progress)
|
366
|
+
mixtures = par_track(
|
367
|
+
partial(
|
368
|
+
_process_mixture,
|
369
|
+
location=location,
|
370
|
+
save_mix=save_mix,
|
371
|
+
save_ft=save_ft,
|
372
|
+
save_segsnr=save_segsnr,
|
373
|
+
test=test,
|
374
|
+
),
|
375
|
+
mixtures,
|
376
|
+
progress=progress,
|
377
|
+
)
|
368
378
|
progress.close()
|
369
379
|
|
370
380
|
populate_mixture_table(location, mixtures, test)
|
@@ -378,77 +388,88 @@ def genmixdb(location: str,
|
|
378
388
|
noise_samples_percent = (float(used_noise_samples) / float(noise_audio_samples)) * 100
|
379
389
|
|
380
390
|
if logging:
|
381
|
-
log_duration_and_sizes(
|
382
|
-
|
383
|
-
|
384
|
-
|
385
|
-
|
386
|
-
|
387
|
-
|
388
|
-
|
389
|
-
logger.info(
|
390
|
-
logger.info(
|
391
|
+
log_duration_and_sizes(
|
392
|
+
total_duration=total_duration,
|
393
|
+
num_classes=mixdb.num_classes,
|
394
|
+
feature_step_samples=mixdb.feature_step_samples,
|
395
|
+
feature_parameters=mixdb.feature_parameters,
|
396
|
+
stride=mixdb.fg_stride,
|
397
|
+
desc="Actual",
|
398
|
+
)
|
399
|
+
logger.info("")
|
400
|
+
logger.info(f"Used {noise_files_percent:,.0f}% of noise files")
|
401
|
+
logger.info(f"Used {noise_samples_percent:,.0f}% of noise audio")
|
402
|
+
logger.info("")
|
391
403
|
|
392
404
|
if not test and save_json:
|
393
405
|
if logging:
|
394
|
-
logger.info(f
|
406
|
+
logger.info(f"Writing JSON version of database to {location}")
|
395
407
|
mixdb = MixtureDatabase(location)
|
396
408
|
mixdb.save()
|
397
409
|
|
398
|
-
return mixdb
|
399
|
-
|
400
410
|
|
401
|
-
def
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
411
|
+
def _process_mixture(
|
412
|
+
mixture: Mixture,
|
413
|
+
location: str,
|
414
|
+
save_mix: bool,
|
415
|
+
save_ft: bool,
|
416
|
+
save_segsnr: bool,
|
417
|
+
test: bool,
|
418
|
+
) -> Mixture:
|
409
419
|
from typing import Any
|
410
420
|
|
421
|
+
from sonusai.mixture import MixtureDatabase
|
411
422
|
from sonusai.mixture import get_ft
|
412
423
|
from sonusai.mixture import get_segsnr
|
413
|
-
from sonusai.mixture import
|
424
|
+
from sonusai.mixture import get_truth
|
414
425
|
from sonusai.mixture import update_mixture
|
415
|
-
from sonusai.mixture import
|
426
|
+
from sonusai.mixture import write_cached_data
|
416
427
|
from sonusai.mixture import write_mixture_metadata
|
417
428
|
|
418
|
-
with_data =
|
419
|
-
mixdb =
|
429
|
+
with_data = save_mix or save_ft
|
430
|
+
mixdb = MixtureDatabase(location, test)
|
420
431
|
|
421
432
|
mixture, genmix_data = update_mixture(mixdb, mixture, with_data)
|
422
433
|
|
423
434
|
if with_data:
|
424
435
|
write_data: list[tuple[str, Any]] = []
|
425
436
|
|
426
|
-
if
|
427
|
-
write_data.append((
|
428
|
-
write_data.append((
|
429
|
-
write_data.append((
|
430
|
-
|
431
|
-
if
|
432
|
-
|
433
|
-
|
434
|
-
|
435
|
-
|
436
|
-
|
437
|
-
|
438
|
-
|
439
|
-
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
|
446
|
-
|
447
|
-
|
448
|
-
|
449
|
-
|
450
|
-
|
451
|
-
|
437
|
+
if save_mix:
|
438
|
+
write_data.append(("targets", genmix_data.targets))
|
439
|
+
write_data.append(("noise", genmix_data.noise))
|
440
|
+
write_data.append(("mixture", genmix_data.mixture))
|
441
|
+
|
442
|
+
if save_ft:
|
443
|
+
if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
|
444
|
+
raise RuntimeError("Mixture data was not generated properly")
|
445
|
+
truth_t = get_truth(
|
446
|
+
mixdb=mixdb,
|
447
|
+
mixture=mixture,
|
448
|
+
targets_audio=genmix_data.targets,
|
449
|
+
noise_audio=genmix_data.noise,
|
450
|
+
mixture_audio=genmix_data.mixture,
|
451
|
+
)
|
452
|
+
feature, truth_f = get_ft(
|
453
|
+
mixdb=mixdb,
|
454
|
+
mixture=mixture,
|
455
|
+
mixture_audio=genmix_data.mixture,
|
456
|
+
truth_t=truth_t,
|
457
|
+
)
|
458
|
+
write_data.append(("feature", feature))
|
459
|
+
write_data.append(("truth_f", truth_f))
|
460
|
+
|
461
|
+
if save_segsnr:
|
462
|
+
if genmix_data.target is None:
|
463
|
+
raise RuntimeError("Target data was not generated properly")
|
464
|
+
segsnr = get_segsnr(
|
465
|
+
mixdb=mixdb,
|
466
|
+
mixture=mixture,
|
467
|
+
target_audio=genmix_data.target,
|
468
|
+
noise=genmix_data.noise,
|
469
|
+
)
|
470
|
+
write_data.append(("segsnr", segsnr))
|
471
|
+
|
472
|
+
write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
|
452
473
|
write_mixture_metadata(mixdb, mixture)
|
453
474
|
|
454
475
|
return mixture
|
@@ -478,13 +499,13 @@ def main() -> None:
|
|
478
499
|
from sonusai.mixture import load_config
|
479
500
|
from sonusai.utils import seconds_to_hms
|
480
501
|
|
481
|
-
verbose = args[
|
482
|
-
save_mix = args[
|
483
|
-
save_ft = args[
|
484
|
-
save_segsnr = args[
|
485
|
-
dryrun = args[
|
486
|
-
save_json = args[
|
487
|
-
location = args[
|
502
|
+
verbose = args["--verbose"]
|
503
|
+
save_mix = args["--mix"]
|
504
|
+
save_ft = args["--ft"]
|
505
|
+
save_segsnr = args["--segsnr"]
|
506
|
+
dryrun = args["--dryrun"]
|
507
|
+
save_json = args["--json"]
|
508
|
+
location = args["LOC"]
|
488
509
|
|
489
510
|
start_time = time.monotonic()
|
490
511
|
|
@@ -493,30 +514,36 @@ def main() -> None:
|
|
493
514
|
|
494
515
|
makedirs(location, exist_ok=True)
|
495
516
|
|
496
|
-
create_file_handler(join(location,
|
517
|
+
create_file_handler(join(location, "genmixdb.log"))
|
497
518
|
update_console_handler(verbose)
|
498
|
-
initial_log_messages(
|
519
|
+
initial_log_messages("genmixdb")
|
499
520
|
|
500
521
|
if dryrun:
|
501
522
|
config = load_config(location)
|
502
|
-
logger.info(
|
523
|
+
logger.info("Dryrun configuration:")
|
503
524
|
logger.info(yaml.dump(config))
|
504
525
|
return
|
505
526
|
|
506
|
-
logger.info(f
|
507
|
-
logger.info(
|
508
|
-
|
509
|
-
|
510
|
-
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
527
|
+
logger.info(f"Creating mixture database for {location}")
|
528
|
+
logger.info("")
|
529
|
+
|
530
|
+
try:
|
531
|
+
genmixdb(
|
532
|
+
location=location,
|
533
|
+
save_mix=save_mix,
|
534
|
+
save_ft=save_ft,
|
535
|
+
save_segsnr=save_segsnr,
|
536
|
+
show_progress=True,
|
537
|
+
save_json=save_json,
|
538
|
+
)
|
539
|
+
except Exception as e:
|
540
|
+
logger.debug(e)
|
541
|
+
raise
|
515
542
|
|
516
543
|
end_time = time.monotonic()
|
517
|
-
logger.info(f
|
518
|
-
logger.info(
|
544
|
+
logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
|
545
|
+
logger.info("")
|
519
546
|
|
520
547
|
|
521
|
-
if __name__ ==
|
548
|
+
if __name__ == "__main__":
|
522
549
|
main()
|