sonusai 0.19.8__py3-none-any.whl → 0.19.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +32 -54
- sonusai/metrics_summary.py +320 -0
- sonusai/mixture/__init__.py +2 -1
- sonusai/mixture/audio.py +40 -7
- sonusai/mixture/generation.py +100 -121
- sonusai/mixture/helpers.py +22 -7
- sonusai/mixture/mixdb.py +90 -30
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/METADATA +1 -1
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/RECORD +27 -24
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/WHEEL +0 -0
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/entry_points.txt +0 -0
sonusai/mixture/audio.py
CHANGED
@@ -44,44 +44,77 @@ def validate_input_file(input_filepath: str | Path) -> None:
|
|
44
44
|
raise OSError(f"This installation cannot process .{ext} files")
|
45
45
|
|
46
46
|
|
47
|
-
|
48
|
-
def get_sample_rate(name: str | Path) -> int:
|
47
|
+
def get_sample_rate(name: str | Path, use_cache: bool = True) -> int:
|
49
48
|
"""Get sample rate from audio file
|
50
49
|
|
51
50
|
:param name: File name
|
51
|
+
:param use_cache: If true, use LRU caching
|
52
52
|
:return: Sample rate
|
53
53
|
"""
|
54
|
+
if use_cache:
|
55
|
+
return _get_sample_rate(name)
|
56
|
+
return _get_sample_rate.__wrapped__(name)
|
57
|
+
|
58
|
+
|
59
|
+
@lru_cache
|
60
|
+
def _get_sample_rate(name: str | Path) -> int:
|
54
61
|
from .soundfile_audio import get_sample_rate
|
55
62
|
|
56
63
|
return get_sample_rate(name)
|
57
64
|
|
58
65
|
|
59
|
-
|
60
|
-
def read_audio(name: str | Path) -> AudioT:
|
66
|
+
def read_audio(name: str | Path, use_cache: bool = True) -> AudioT:
|
61
67
|
"""Read audio data from a file
|
62
68
|
|
63
69
|
:param name: File name
|
70
|
+
:param use_cache: If true, use LRU caching
|
64
71
|
:return: Array of time domain audio data
|
65
72
|
"""
|
73
|
+
if use_cache:
|
74
|
+
return _read_audio(name)
|
75
|
+
return _read_audio.__wrapped__(name)
|
76
|
+
|
77
|
+
|
78
|
+
@lru_cache
|
79
|
+
def _read_audio(name: str | Path) -> AudioT:
|
66
80
|
from .soundfile_audio import read_audio
|
67
81
|
|
68
82
|
return read_audio(name)
|
69
83
|
|
70
84
|
|
71
|
-
|
72
|
-
def read_ir(name: str | Path) -> ImpulseResponseData:
|
85
|
+
def read_ir(name: str | Path, use_cache: bool = True) -> ImpulseResponseData:
|
73
86
|
"""Read impulse response data
|
74
87
|
|
75
88
|
:param name: File name
|
89
|
+
:param use_cache: If true, use LRU caching
|
76
90
|
:return: ImpulseResponseData object
|
77
91
|
"""
|
92
|
+
if use_cache:
|
93
|
+
return _read_ir(name)
|
94
|
+
return _read_ir.__wrapped__(name)
|
95
|
+
|
96
|
+
|
97
|
+
@lru_cache
|
98
|
+
def _read_ir(name: str | Path) -> ImpulseResponseData:
|
78
99
|
from .soundfile_audio import read_ir
|
79
100
|
|
80
101
|
return read_ir(name)
|
81
102
|
|
82
103
|
|
104
|
+
def get_num_samples(name: str | Path, use_cache: bool = True) -> int:
|
105
|
+
"""Get the number of samples resampled to the SonusAI sample rate in the given file
|
106
|
+
|
107
|
+
:param name: File name
|
108
|
+
:param use_cache: If true, use LRU caching
|
109
|
+
:return: number of samples in resampled audio
|
110
|
+
"""
|
111
|
+
if use_cache:
|
112
|
+
return _get_num_samples(name)
|
113
|
+
return _get_num_samples.__wrapped__(name)
|
114
|
+
|
115
|
+
|
83
116
|
@lru_cache
|
84
|
-
def get_num_samples(name: str | Path) -> int:
|
117
|
+
def _get_num_samples(name: str | Path) -> int:
|
85
118
|
"""Get the number of samples resampled to the SonusAI sample rate in the given file
|
86
119
|
|
87
120
|
:param name: File name
|
sonusai/mixture/generation.py
CHANGED
@@ -119,8 +119,7 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
119
119
|
id INTEGER PRIMARY KEY NOT NULL,
|
120
120
|
file_id INTEGER NOT NULL,
|
121
121
|
augmentation TEXT NOT NULL,
|
122
|
-
FOREIGN KEY(file_id) REFERENCES target_file (id)
|
123
|
-
UNIQUE(file_id, augmentation))
|
122
|
+
FOREIGN KEY(file_id) REFERENCES target_file (id))
|
124
123
|
""")
|
125
124
|
|
126
125
|
con.execute("""
|
@@ -389,8 +388,7 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
|
|
389
388
|
con.close()
|
390
389
|
|
391
390
|
|
392
|
-
def populate_mixture_table(
|
393
|
-
location: str,
|
391
|
+
def generate_mixtures(
|
394
392
|
noise_mix_mode: str,
|
395
393
|
augmented_targets: list[AugmentedTarget],
|
396
394
|
target_files: list[TargetFile],
|
@@ -403,9 +401,8 @@ def populate_mixture_table(
|
|
403
401
|
num_classes: int,
|
404
402
|
feature_step_samples: int,
|
405
403
|
num_ir: int,
|
406
|
-
|
407
|
-
|
408
|
-
"""Generate mixtures and populate mixture table"""
|
404
|
+
) -> tuple[int, int, list[Mixture]]:
|
405
|
+
"""Generate mixtures"""
|
409
406
|
if noise_mix_mode == "exhaustive":
|
410
407
|
func = _exhaustive_noise_mix
|
411
408
|
elif noise_mix_mode == "non-exhaustive":
|
@@ -415,8 +412,7 @@ def populate_mixture_table(
|
|
415
412
|
else:
|
416
413
|
raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
|
417
414
|
|
418
|
-
|
419
|
-
location=location,
|
415
|
+
return func(
|
420
416
|
augmented_targets=augmented_targets,
|
421
417
|
target_files=target_files,
|
422
418
|
target_augmentations=target_augmentations,
|
@@ -428,23 +424,76 @@ def populate_mixture_table(
|
|
428
424
|
num_classes=num_classes,
|
429
425
|
feature_step_samples=feature_step_samples,
|
430
426
|
num_ir=num_ir,
|
431
|
-
test=test,
|
432
427
|
)
|
433
428
|
|
434
|
-
|
429
|
+
|
430
|
+
def populate_mixture_table(
|
431
|
+
location: str,
|
432
|
+
mixtures: list[Mixture],
|
433
|
+
test: bool = False,
|
434
|
+
logging: bool = False,
|
435
|
+
show_progress: bool = False,
|
436
|
+
) -> None:
|
437
|
+
"""Populate mixture table"""
|
438
|
+
from sonusai import logger
|
439
|
+
from sonusai.utils import track
|
440
|
+
|
441
|
+
from .helpers import from_mixture
|
442
|
+
from .helpers import from_target
|
443
|
+
from .mixdb import db_connection
|
444
|
+
|
445
|
+
con = db_connection(location=location, readonly=False, test=test)
|
446
|
+
|
447
|
+
# Populate target table
|
448
|
+
if logging:
|
449
|
+
logger.info("Populating target table")
|
450
|
+
targets: list[tuple[int, str]] = []
|
451
|
+
for mixture in mixtures:
|
452
|
+
for target in mixture.targets:
|
453
|
+
entry = from_target(target)
|
454
|
+
if entry not in targets:
|
455
|
+
targets.append(entry)
|
456
|
+
for target in track(targets, disable=not show_progress):
|
457
|
+
con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
|
458
|
+
|
459
|
+
# Populate mixture table
|
460
|
+
if logging:
|
461
|
+
logger.info("Populating mixture table")
|
462
|
+
for mixture in track(mixtures, disable=not show_progress):
|
463
|
+
m_id = int(mixture.name)
|
464
|
+
con.execute(
|
465
|
+
"""
|
466
|
+
INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
|
467
|
+
snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
|
468
|
+
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
469
|
+
""",
|
470
|
+
(m_id + 1, *from_mixture(mixture)),
|
471
|
+
)
|
472
|
+
|
473
|
+
for target in mixture.targets:
|
474
|
+
target_id = con.execute(
|
475
|
+
"""
|
476
|
+
SELECT target.id
|
477
|
+
FROM target
|
478
|
+
WHERE ? = target.file_id AND ? = target.augmentation
|
479
|
+
""",
|
480
|
+
from_target(target),
|
481
|
+
).fetchone()[0]
|
482
|
+
con.execute(
|
483
|
+
"INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
|
484
|
+
(m_id + 1, target_id),
|
485
|
+
)
|
486
|
+
|
487
|
+
con.commit()
|
488
|
+
con.close()
|
435
489
|
|
436
490
|
|
437
|
-
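def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData: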
|
491
|
+
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
438
492
|
"""Update mixture record with name and gains"""
|
439
493
|
from .audio import get_next_noise
|
440
494
|
from .augmentation import apply_gain
|
441
|
-
from .datatypes import GenMixData
|
442
|
-
from .helpers import from_mixture
|
443
495
|
from .helpers import get_target
|
444
|
-
from .mixdb import db_connection
|
445
496
|
|
446
|
-
mixdb = MixtureDatabase(location, test)
|
447
|
-
mixture = mixdb.mixture(m_id)
|
448
497
|
mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
|
449
498
|
|
450
499
|
noise_audio = _augmented_noise_audio(mixdb, mixture)
|
@@ -459,29 +508,8 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
459
508
|
|
460
509
|
mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
|
461
510
|
|
462
|
-
con = db_connection(location=location, readonly=False, test=test)
|
463
|
-
con.execute(
|
464
|
-
"""
|
465
|
-
UPDATE mixture SET name=?,
|
466
|
-
noise_file_id=?,
|
467
|
-
noise_augmentation=?,
|
468
|
-
noise_offset=?,
|
469
|
-
noise_snr_gain=?,
|
470
|
-
random_snr=?,
|
471
|
-
snr=?,
|
472
|
-
samples=?,
|
473
|
-
spectral_mask_id=?,
|
474
|
-
spectral_mask_seed=?,
|
475
|
-
target_snr_gain=?
|
476
|
-
WHERE ? = mixture.id
|
477
|
-
""",
|
478
|
-
(*from_mixture(mixture), m_id + 1),
|
479
|
-
)
|
480
|
-
con.commit()
|
481
|
-
con.close()
|
482
|
-
|
483
511
|
if not with_data:
|
484
|
-
return GenMixData()
|
512
|
+
return mixture, GenMixData()
|
485
513
|
|
486
514
|
# Apply SNR gains
|
487
515
|
targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
|
@@ -491,7 +519,7 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
491
519
|
target_audio = get_target(mixdb, mixture, targets_audio)
|
492
520
|
mixture_audio = target_audio + noise_audio
|
493
521
|
|
494
|
-
return GenMixData(
|
522
|
+
return mixture, GenMixData(
|
495
523
|
mixture=mixture_audio,
|
496
524
|
targets=targets_audio,
|
497
525
|
target=target_audio,
|
@@ -511,7 +539,7 @@ def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
|
|
511
539
|
audio = read_audio(noise.name)
|
512
540
|
audio = apply_augmentation(audio, noise_augmentation)
|
513
541
|
if noise_augmentation.ir is not None:
|
514
|
-
audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir)))
|
542
|
+
audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir))) # pyright: ignore [reportArgumentType]
|
515
543
|
|
516
544
|
return audio
|
517
545
|
|
@@ -540,7 +568,10 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
540
568
|
|
541
569
|
|
542
570
|
def _initialize_mixture_gains(
|
543
|
-
mixdb: MixtureDatabase,
|
571
|
+
mixdb: MixtureDatabase,
|
572
|
+
mixture: Mixture,
|
573
|
+
target_audio: AudioT,
|
574
|
+
noise_audio: AudioT,
|
544
575
|
) -> Mixture:
|
545
576
|
import numpy as np
|
546
577
|
|
@@ -603,7 +634,6 @@ def _initialize_mixture_gains(
|
|
603
634
|
|
604
635
|
|
605
636
|
def _exhaustive_noise_mix(
|
606
|
-
location: str,
|
607
637
|
augmented_targets: list[AugmentedTarget],
|
608
638
|
target_files: list[TargetFile],
|
609
639
|
target_augmentations: list[AugmentationRule],
|
@@ -615,9 +645,8 @@ def _exhaustive_noise_mix(
|
|
615
645
|
num_classes: int,
|
616
646
|
feature_step_samples: int,
|
617
647
|
num_ir: int,
|
618
|
-
|
619
|
-
|
620
|
-
"""Use every noise/augmentation with every target/augmentation"""
|
648
|
+
) -> tuple[int, int, list[Mixture]]:
|
649
|
+
"""Use every noise/augmentation with every target/augmentation+interferences/augmentation"""
|
621
650
|
from random import randint
|
622
651
|
|
623
652
|
import numpy as np
|
@@ -643,6 +672,8 @@ def _exhaustive_noise_mix(
|
|
643
672
|
)
|
644
673
|
for mixup in mixups
|
645
674
|
]
|
675
|
+
|
676
|
+
mixtures: list[Mixture] = []
|
646
677
|
for noise_file_id in range(len(noise_files)):
|
647
678
|
for noise_augmentation_rule in noise_augmentations:
|
648
679
|
noise_augmentation = augmentation_from_rule(noise_augmentation_rule, num_ir)
|
@@ -665,10 +696,8 @@ def _exhaustive_noise_mix(
|
|
665
696
|
|
666
697
|
for spectral_mask_id in range(len(spectral_masks)):
|
667
698
|
for snr in all_snrs:
|
668
|
-
|
669
|
-
|
670
|
-
m_id=m_id,
|
671
|
-
mixture=Mixture(
|
699
|
+
mixtures.append(
|
700
|
+
Mixture(
|
672
701
|
targets=targets,
|
673
702
|
name=str(m_id),
|
674
703
|
noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
|
@@ -677,19 +706,17 @@ def _exhaustive_noise_mix(
|
|
677
706
|
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
678
707
|
spectral_mask_id=spectral_mask_id + 1,
|
679
708
|
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
680
|
-
)
|
681
|
-
test=test,
|
709
|
+
)
|
682
710
|
)
|
683
711
|
m_id += 1
|
684
712
|
|
685
713
|
noise_offset = int((noise_offset + target_length) % noise_length)
|
686
714
|
used_noise_samples += target_length
|
687
715
|
|
688
|
-
return used_noise_files, used_noise_samples
|
716
|
+
return used_noise_files, used_noise_samples, mixtures
|
689
717
|
|
690
718
|
|
691
719
|
def _non_exhaustive_noise_mix(
|
692
|
-
location: str,
|
693
720
|
augmented_targets: list[AugmentedTarget],
|
694
721
|
target_files: list[TargetFile],
|
695
722
|
target_augmentations: list[AugmentationRule],
|
@@ -701,10 +728,9 @@ def _non_exhaustive_noise_mix(
|
|
701
728
|
num_classes: int,
|
702
729
|
feature_step_samples: int,
|
703
730
|
num_ir: int,
|
704
|
-
|
705
|
-
|
706
|
-
|
707
|
-
(reduced data set).
|
731
|
+
) -> tuple[int, int, list[Mixture]]:
|
732
|
+
"""Cycle through every target/augmentation+interferences/augmentation without necessarily using all
|
733
|
+
noise/augmentation combinations (reduced data set).
|
708
734
|
"""
|
709
735
|
from random import randint
|
710
736
|
|
@@ -732,6 +758,8 @@ def _non_exhaustive_noise_mix(
|
|
732
758
|
)
|
733
759
|
for mixup in mixups
|
734
760
|
]
|
761
|
+
|
762
|
+
mixtures: list[Mixture] = []
|
735
763
|
for mixup in augmented_target_indices_for_mixups:
|
736
764
|
for augmented_target_indices in mixup:
|
737
765
|
targets, target_length = _get_target_info(
|
@@ -763,10 +791,8 @@ def _non_exhaustive_noise_mix(
|
|
763
791
|
|
764
792
|
used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
|
765
793
|
|
766
|
-
|
767
|
-
|
768
|
-
m_id=m_id,
|
769
|
-
mixture=Mixture(
|
794
|
+
mixtures.append(
|
795
|
+
Mixture(
|
770
796
|
targets=targets,
|
771
797
|
name=str(m_id),
|
772
798
|
noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
|
@@ -775,16 +801,14 @@ def _non_exhaustive_noise_mix(
|
|
775
801
|
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
776
802
|
spectral_mask_id=spectral_mask_id + 1,
|
777
803
|
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
778
|
-
)
|
779
|
-
test=test,
|
804
|
+
)
|
780
805
|
)
|
781
806
|
m_id += 1
|
782
807
|
|
783
|
-
return len(used_noise_files), used_noise_samples
|
808
|
+
return len(used_noise_files), used_noise_samples, mixtures
|
784
809
|
|
785
810
|
|
786
811
|
def _non_combinatorial_noise_mix(
|
787
|
-
location: str,
|
788
812
|
augmented_targets: list[AugmentedTarget],
|
789
813
|
target_files: list[TargetFile],
|
790
814
|
target_augmentations: list[AugmentationRule],
|
@@ -796,11 +820,10 @@ def _non_combinatorial_noise_mix(
|
|
796
820
|
num_classes: int,
|
797
821
|
feature_step_samples: int,
|
798
822
|
num_ir: int,
|
799
|
-
|
800
|
-
|
801
|
-
|
802
|
-
|
803
|
-
beginning if end of noise/augmentation is reached.
|
823
|
+
) -> tuple[int, int, list[Mixture]]:
|
824
|
+
"""Combine a target/augmentation+interferences/augmentation with a single cut of a noise/augmentation
|
825
|
+
non-exhaustively (each target/augmentation+interferences/augmentation does not use each noise/augmentation).
|
826
|
+
Cut has random start and loop back to beginning if end of noise/augmentation is reached.
|
804
827
|
"""
|
805
828
|
from random import choice
|
806
829
|
from random import randint
|
@@ -828,6 +851,8 @@ def _non_combinatorial_noise_mix(
|
|
828
851
|
)
|
829
852
|
for mixup in mixups
|
830
853
|
]
|
854
|
+
|
855
|
+
mixtures: list[Mixture] = []
|
831
856
|
for mixup in augmented_target_indices_for_mixups:
|
832
857
|
for augmented_target_indices in mixup:
|
833
858
|
targets, target_length = _get_target_info(
|
@@ -857,10 +882,8 @@ def _non_combinatorial_noise_mix(
|
|
857
882
|
|
858
883
|
used_noise_files.add(f"{noise_file_id}_{noise_augmentation_id}")
|
859
884
|
|
860
|
-
|
861
|
-
|
862
|
-
m_id=m_id,
|
863
|
-
mixture=Mixture(
|
885
|
+
mixtures.append(
|
886
|
+
Mixture(
|
864
887
|
targets=targets,
|
865
888
|
name=str(m_id),
|
866
889
|
noise=Noise(file_id=noise_file_id + 1, augmentation=noise_augmentation),
|
@@ -869,12 +892,11 @@ def _non_combinatorial_noise_mix(
|
|
869
892
|
snr=UniversalSNR(value=snr.value, is_random=snr.is_random),
|
870
893
|
spectral_mask_id=spectral_mask_id + 1,
|
871
894
|
spectral_mask_seed=randint(0, np.iinfo("i").max), # noqa: S311
|
872
|
-
)
|
873
|
-
test=test,
|
895
|
+
)
|
874
896
|
)
|
875
897
|
m_id += 1
|
876
898
|
|
877
|
-
return len(used_noise_files), used_noise_samples
|
899
|
+
return len(used_noise_files), used_noise_samples, mixtures
|
878
900
|
|
879
901
|
|
880
902
|
def _get_next_noise_indices(
|
@@ -973,49 +995,6 @@ def _get_target_info(
|
|
973
995
|
return mixups, target_length
|
974
996
|
|
975
997
|
|
976
|
-
def _insert_mixture_record(location: str, m_id: int, mixture: Mixture, test: bool = False) -> None:
|
977
|
-
from .helpers import from_mixture
|
978
|
-
from .helpers import from_target
|
979
|
-
from .mixdb import db_connection
|
980
|
-
|
981
|
-
con = db_connection(location=location, readonly=False, test=test)
|
982
|
-
# Populate target table
|
983
|
-
for target in mixture.targets:
|
984
|
-
con.execute(
|
985
|
-
"""
|
986
|
-
INSERT OR IGNORE INTO target (file_id, augmentation)
|
987
|
-
VALUES (?, ?)
|
988
|
-
""",
|
989
|
-
from_target(target),
|
990
|
-
)
|
991
|
-
|
992
|
-
# Populate mixture table
|
993
|
-
con.execute(
|
994
|
-
"""
|
995
|
-
INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
|
996
|
-
snr, samples, spectral_mask_id, spectral_mask_seed, target_snr_gain)
|
997
|
-
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
998
|
-
""",
|
999
|
-
(m_id + 1, *from_mixture(mixture)),
|
1000
|
-
)
|
1001
|
-
|
1002
|
-
for target in mixture.targets:
|
1003
|
-
target_id = con.execute(
|
1004
|
-
"""
|
1005
|
-
SELECT target.id
|
1006
|
-
FROM target
|
1007
|
-
WHERE ? = target.file_id AND ? = target.augmentation
|
1008
|
-
""",
|
1009
|
-
from_target(target),
|
1010
|
-
).fetchone()[0]
|
1011
|
-
con.execute(
|
1012
|
-
"INSERT INTO mixture_target (mixture_id, target_id) VALUES (?, ?)",
|
1013
|
-
(m_id + 1, target_id),
|
1014
|
-
)
|
1015
|
-
con.commit()
|
1016
|
-
con.close()
|
1017
|
-
|
1018
|
-
|
1019
998
|
def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
|
1020
999
|
from .datatypes import UniversalSNRGenerator
|
1021
1000
|
|
sonusai/mixture/helpers.py
CHANGED
@@ -135,14 +135,20 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
|
|
135
135
|
return results
|
136
136
|
|
137
137
|
|
138
|
-
def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
138
|
+
def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> str:
|
139
139
|
"""Create a string of metadata for a Mixture
|
140
140
|
|
141
141
|
:param mixdb: Mixture database
|
142
142
|
:param m_id: Mixture ID
|
143
|
+
:param mixture: Mixture record
|
143
144
|
:return: String of metadata
|
144
145
|
"""
|
145
|
-
|
146
|
+
if m_id is not None:
|
147
|
+
mixture = mixdb.mixture(m_id)
|
148
|
+
|
149
|
+
if mixture is None:
|
150
|
+
raise ValueError("No mixture specified.")
|
151
|
+
|
146
152
|
metadata = ""
|
147
153
|
speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
|
148
154
|
for mi, target in enumerate(mixture.targets):
|
@@ -173,17 +179,25 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
|
173
179
|
return metadata
|
174
180
|
|
175
181
|
|
176
|
-
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> None:
|
182
|
+
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> None:
|
177
183
|
"""Write mixture metadata to a text file
|
178
184
|
|
179
185
|
:param mixdb: Mixture database
|
180
186
|
:param m_id: Mixture ID
|
187
|
+
:param mixture: Mixture record
|
181
188
|
"""
|
182
189
|
from os.path import join
|
183
190
|
|
184
|
-
|
191
|
+
if m_id is not None:
|
192
|
+
name = mixdb.mixture(m_id).name
|
193
|
+
elif mixture is not None:
|
194
|
+
name = mixture.name
|
195
|
+
else:
|
196
|
+
raise ValueError("No mixture specified.")
|
197
|
+
|
198
|
+
name = join(mixdb.location, "mixture", name, "metadata.txt")
|
185
199
|
with open(file=name, mode="w") as f:
|
186
|
-
f.write(mixture_metadata(mixdb, m_id))
|
200
|
+
f.write(mixture_metadata(mixdb, m_id, mixture))
|
187
201
|
|
188
202
|
|
189
203
|
def from_mixture(
|
@@ -246,12 +260,13 @@ def to_target(entry: TargetRecord) -> Target:
|
|
246
260
|
)
|
247
261
|
|
248
262
|
|
249
|
-
def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT]) -> AudioT:
|
263
|
+
def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[AudioT], use_cache: bool = True) -> AudioT:
|
250
264
|
"""Get the augmented target audio data for the given mixture record
|
251
265
|
|
252
266
|
:param mixdb: Mixture database
|
253
267
|
:param mixture: Mixture record
|
254
268
|
:param targets_audio: List of augmented target audio data (one per target in the mixup)
|
269
|
+
:param use_cache: If true, use LRU caching
|
255
270
|
:return: Sum of augmented target audio data
|
256
271
|
"""
|
257
272
|
# Apply impulse responses to targets
|
@@ -265,7 +280,7 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[Aud
|
|
265
280
|
ir_idx = mixture.targets[idx].augmentation.ir
|
266
281
|
if ir_idx is not None:
|
267
282
|
targets_ir.append(
|
268
|
-
apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx))))
|
283
|
+
apply_impulse_response(audio=target, ir=read_ir(mixdb.impulse_response_file(int(ir_idx)), use_cache)) # pyright: ignore [reportArgumentType]
|
269
284
|
)
|
270
285
|
else:
|
271
286
|
targets_ir.append(target)
|