sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/genmixdb.yml +4 -2
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/doc/doc.py +14 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/ir_metric.py +555 -0
- sonusai/metrics_summary.py +322 -0
- sonusai/mixture/__init__.py +6 -2
- sonusai/mixture/audio.py +139 -15
- sonusai/mixture/augmentation.py +199 -84
- sonusai/mixture/config.py +9 -4
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +19 -10
- sonusai/mixture/generation.py +52 -64
- sonusai/mixture/helpers.py +38 -26
- sonusai/mixture/ir_delay.py +63 -0
- sonusai/mixture/mixdb.py +190 -46
- sonusai/mixture/targets.py +3 -6
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- sonusai/utils/temp_seed.py +13 -0
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
- sonusai/mixture/soundfile_audio.py +0 -130
- sonusai/mixture/sox_audio.py +0 -476
- sonusai/mixture/sox_augmentation.py +0 -136
- sonusai/mixture/torchaudio_audio.py +0 -106
- sonusai/mixture/torchaudio_augmentation.py +0 -109
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -61,7 +61,7 @@ def db_connection(
|
|
61
61
|
if not create and readonly:
|
62
62
|
name += "?mode=ro"
|
63
63
|
|
64
|
-
connection = sqlite3.connect("file:" + name, uri=True)
|
64
|
+
connection = sqlite3.connect("file:" + name, uri=True, timeout=20)
|
65
65
|
|
66
66
|
if verbose:
|
67
67
|
connection.set_trace_callback(print)
|
@@ -84,7 +84,7 @@ class SQLiteContextManager:
|
|
84
84
|
|
85
85
|
|
86
86
|
class MixtureDatabase:
|
87
|
-
def __init__(self, location: str, test: bool = False) -> None:
|
87
|
+
def __init__(self, location: str, test: bool = False, use_cache: bool = True) -> None:
|
88
88
|
import json
|
89
89
|
from os.path import exists
|
90
90
|
|
@@ -92,6 +92,7 @@ class MixtureDatabase:
|
|
92
92
|
|
93
93
|
self.location = location
|
94
94
|
self.test = test
|
95
|
+
self.use_cache = use_cache
|
95
96
|
|
96
97
|
if not exists(db_file(self.location, self.test)):
|
97
98
|
raise OSError(f"Could not find mixture database in {self.location}")
|
@@ -121,7 +122,7 @@ class MixtureDatabase:
|
|
121
122
|
class_weights_threshold=self.class_weights_thresholds,
|
122
123
|
feature=self.feature,
|
123
124
|
impulse_response_files=self.impulse_response_files,
|
124
|
-
mixtures=self.mixtures,
|
125
|
+
mixtures=self.mixtures(),
|
125
126
|
noise_mix_mode=self.noise_mix_mode,
|
126
127
|
noise_files=self.noise_files,
|
127
128
|
num_classes=self.num_classes,
|
@@ -254,6 +255,16 @@ class MixtureDatabase:
|
|
254
255
|
"Predicted rating of overall quality of mixture versus true targets",
|
255
256
|
),
|
256
257
|
MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
|
258
|
+
MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
|
259
|
+
MetricDoc("Mixture Metrics", "mxmin", "Mixture min level"),
|
260
|
+
MetricDoc("Mixture Metrics", "mxmax", "Mixture max levl"),
|
261
|
+
MetricDoc("Mixture Metrics", "mxpkdb", "Mixture Pk lev dB"),
|
262
|
+
MetricDoc("Mixture Metrics", "mxlrms", "Mixture RMS lev dB"),
|
263
|
+
MetricDoc("Mixture Metrics", "mxpkr", "Mixture RMS Pk dB"),
|
264
|
+
MetricDoc("Mixture Metrics", "mxtr", "Mixture RMS Tr dB"),
|
265
|
+
MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
|
266
|
+
MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
|
267
|
+
MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
|
257
268
|
MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
|
258
269
|
MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
|
259
270
|
MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
|
@@ -488,7 +499,7 @@ class MixtureDatabase:
|
|
488
499
|
return truth_configs
|
489
500
|
|
490
501
|
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
491
|
-
return _target_truth_configs(self.db, t_id)
|
502
|
+
return _target_truth_configs(self.db, t_id, self.use_cache)
|
492
503
|
|
493
504
|
@cached_property
|
494
505
|
def random_snrs(self) -> list[float]:
|
@@ -556,7 +567,7 @@ class MixtureDatabase:
|
|
556
567
|
:param sm_id: Spectral mask ID
|
557
568
|
:return: Spectral mask
|
558
569
|
"""
|
559
|
-
return _spectral_mask(self.db, sm_id)
|
570
|
+
return _spectral_mask(self.db, sm_id, self.use_cache)
|
560
571
|
|
561
572
|
@cached_property
|
562
573
|
def target_files(self) -> list[TargetFile]:
|
@@ -619,7 +630,7 @@ class MixtureDatabase:
|
|
619
630
|
:param t_id: Target file ID
|
620
631
|
:return: Target file
|
621
632
|
"""
|
622
|
-
return _target_file(self.db, t_id)
|
633
|
+
return _target_file(self.db, t_id, self.use_cache)
|
623
634
|
|
624
635
|
@cached_property
|
625
636
|
def num_target_files(self) -> int:
|
@@ -657,7 +668,7 @@ class MixtureDatabase:
|
|
657
668
|
:param n_id: Noise file ID
|
658
669
|
:return: Noise file
|
659
670
|
"""
|
660
|
-
return _noise_file(self.db, n_id)
|
671
|
+
return _noise_file(self.db, n_id, self.use_cache)
|
661
672
|
|
662
673
|
@cached_property
|
663
674
|
def num_noise_files(self) -> int:
|
@@ -680,7 +691,7 @@ class MixtureDatabase:
|
|
680
691
|
|
681
692
|
with self.db() as c:
|
682
693
|
return [
|
683
|
-
ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
|
694
|
+
ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]), impulse_response[3])
|
684
695
|
for impulse_response in c.execute(
|
685
696
|
"SELECT impulse_response_file.* FROM impulse_response_file"
|
686
697
|
).fetchall()
|
@@ -699,14 +710,24 @@ class MixtureDatabase:
|
|
699
710
|
]
|
700
711
|
|
701
712
|
def impulse_response_file(self, ir_id: int | None) -> str | None:
|
702
|
-
"""Get impulse response file with ID from db
|
713
|
+
"""Get impulse response file name with ID from db
|
714
|
+
|
715
|
+
:param ir_id: Impulse response file ID
|
716
|
+
:return: Impulse response file name
|
717
|
+
"""
|
718
|
+
if ir_id is None:
|
719
|
+
return None
|
720
|
+
return _impulse_response_file(self.db, ir_id, self.use_cache)
|
721
|
+
|
722
|
+
def impulse_response_delay(self, ir_id: int | None) -> int | None:
|
723
|
+
"""Get impulse response delay with ID from db
|
703
724
|
|
704
725
|
:param ir_id: Impulse response file ID
|
705
|
-
:return:
|
726
|
+
:return: Impulse response delay
|
706
727
|
"""
|
707
728
|
if ir_id is None:
|
708
729
|
return None
|
709
|
-
return _impulse_response_file(self.db, ir_id)
|
730
|
+
return _impulse_response_delay(self.db, ir_id, self.use_cache)
|
710
731
|
|
711
732
|
@cached_property
|
712
733
|
def num_impulse_response_files(self) -> int:
|
@@ -717,7 +738,6 @@ class MixtureDatabase:
|
|
717
738
|
with self.db() as c:
|
718
739
|
return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
|
719
740
|
|
720
|
-
@cached_property
|
721
741
|
def mixtures(self) -> list[Mixture]:
|
722
742
|
"""Get mixtures from db
|
723
743
|
|
@@ -760,7 +780,7 @@ class MixtureDatabase:
|
|
760
780
|
:param m_id: Zero-based mixture ID
|
761
781
|
:return: Mixture record
|
762
782
|
"""
|
763
|
-
return _mixture(self.db, m_id)
|
783
|
+
return _mixture(self.db, m_id, self.use_cache)
|
764
784
|
|
765
785
|
@cached_property
|
766
786
|
def mixid_width(self) -> int:
|
@@ -805,7 +825,7 @@ class MixtureDatabase:
|
|
805
825
|
"""
|
806
826
|
from .audio import read_audio
|
807
827
|
|
808
|
-
return read_audio(self.target_file(t_id).name)
|
828
|
+
return read_audio(self.target_file(t_id).name, self.use_cache)
|
809
829
|
|
810
830
|
def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
|
811
831
|
"""Get augmented noise audio
|
@@ -814,18 +834,11 @@ class MixtureDatabase:
|
|
814
834
|
:return: Augmented noise audio
|
815
835
|
"""
|
816
836
|
from .audio import read_audio
|
817
|
-
from .audio import read_ir
|
818
837
|
from .augmentation import apply_augmentation
|
819
|
-
from .augmentation import apply_impulse_response
|
820
838
|
|
821
839
|
noise = self.noise_file(mixture.noise.file_id)
|
822
|
-
audio = read_audio(noise.name)
|
823
|
-
audio = apply_augmentation(audio, mixture.noise.augmentation)
|
824
|
-
if mixture.noise.augmentation.ir is not None:
|
825
|
-
audio = apply_impulse_response(
|
826
|
-
audio,
|
827
|
-
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
|
828
|
-
)
|
840
|
+
audio = read_audio(noise.name, self.use_cache)
|
841
|
+
audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
|
829
842
|
|
830
843
|
return audio
|
831
844
|
|
@@ -859,8 +872,9 @@ class MixtureDatabase:
|
|
859
872
|
for target in mixture.targets:
|
860
873
|
target_audio = self.read_target_audio(target.file_id)
|
861
874
|
target_audio = apply_augmentation(
|
875
|
+
mixdb=self,
|
862
876
|
audio=target_audio,
|
863
|
-
augmentation=target.augmentation,
|
877
|
+
augmentation=target.augmentation.pre,
|
864
878
|
frame_length=self.feature_step_samples,
|
865
879
|
)
|
866
880
|
target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
|
@@ -1119,8 +1133,7 @@ class MixtureDatabase:
|
|
1119
1133
|
offsets = range(0, mixture.samples, self.ft_config.overlap)
|
1120
1134
|
if len(target_energy) != len(offsets):
|
1121
1135
|
raise ValueError(
|
1122
|
-
f"Number of frames in energy, {len(target_energy)},"
|
1123
|
-
f" is not number of frames in mixture, {len(offsets)}"
|
1136
|
+
f"Number of frames in energy, {len(target_energy)}, is not number of frames in mixture, {len(offsets)}"
|
1124
1137
|
)
|
1125
1138
|
|
1126
1139
|
for idx, offset in enumerate(offsets):
|
@@ -1332,7 +1345,7 @@ class MixtureDatabase:
|
|
1332
1345
|
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1333
1346
|
|
1334
1347
|
def speaker(self, s_id: int | None, tier: str) -> str | None:
|
1335
|
-
return _speaker(self.db, s_id, tier)
|
1348
|
+
return _speaker(self.db, s_id, tier, self.use_cache)
|
1336
1349
|
|
1337
1350
|
def speech_metadata(self, tier: str) -> list[str]:
|
1338
1351
|
from .helpers import get_textgrid_tier_from_target_file
|
@@ -1370,11 +1383,11 @@ class MixtureDatabase:
|
|
1370
1383
|
# Check for tempo augmentation and adjust Interval start and end data as needed
|
1371
1384
|
entries = []
|
1372
1385
|
for entry in data:
|
1373
|
-
if target.augmentation.tempo is not None:
|
1386
|
+
if target.augmentation.pre.tempo is not None:
|
1374
1387
|
entries.append(
|
1375
1388
|
Interval(
|
1376
|
-
entry.start / target.augmentation.tempo,
|
1377
|
-
entry.end / target.augmentation.tempo,
|
1389
|
+
entry.start / target.augmentation.pre.tempo,
|
1390
|
+
entry.end / target.augmentation.pre.tempo,
|
1378
1391
|
entry.label,
|
1379
1392
|
)
|
1380
1393
|
)
|
@@ -1464,7 +1477,7 @@ class MixtureDatabase:
|
|
1464
1477
|
|
1465
1478
|
return sorted(result)
|
1466
1479
|
|
1467
|
-
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) ->
|
1480
|
+
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
|
1468
1481
|
"""Get metrics data for the given mixture ID
|
1469
1482
|
|
1470
1483
|
:param m_id: Zero-based mixture ID
|
@@ -1595,6 +1608,19 @@ class MixtureDatabase:
|
|
1595
1608
|
|
1596
1609
|
speech = create_speech()
|
1597
1610
|
|
1611
|
+
def create_mixture_stats() -> Callable[[], AudioStatsMetrics]:
|
1612
|
+
state: AudioStatsMetrics | None = None
|
1613
|
+
|
1614
|
+
def get() -> AudioStatsMetrics:
|
1615
|
+
nonlocal state
|
1616
|
+
if state is None:
|
1617
|
+
state = calc_audio_stats(mixture_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
|
1618
|
+
return state
|
1619
|
+
|
1620
|
+
return get
|
1621
|
+
|
1622
|
+
mixture_stats = create_mixture_stats()
|
1623
|
+
|
1598
1624
|
def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
|
1599
1625
|
state: list[AudioStatsMetrics] | None = None
|
1600
1626
|
|
@@ -1803,6 +1829,36 @@ class MixtureDatabase:
|
|
1803
1829
|
extended=False,
|
1804
1830
|
)
|
1805
1831
|
|
1832
|
+
if m == "mxdco":
|
1833
|
+
return mixture_stats().dco
|
1834
|
+
|
1835
|
+
if m == "mxmin":
|
1836
|
+
return mixture_stats().min
|
1837
|
+
|
1838
|
+
if m == "mxmax":
|
1839
|
+
return mixture_stats().max
|
1840
|
+
|
1841
|
+
if m == "mxpkdb":
|
1842
|
+
return mixture_stats().pkdb
|
1843
|
+
|
1844
|
+
if m == "mxlrms":
|
1845
|
+
return mixture_stats().lrms
|
1846
|
+
|
1847
|
+
if m == "mxpkr":
|
1848
|
+
return mixture_stats().pkr
|
1849
|
+
|
1850
|
+
if m == "mxtr":
|
1851
|
+
return mixture_stats().tr
|
1852
|
+
|
1853
|
+
if m == "mxcr":
|
1854
|
+
return mixture_stats().cr
|
1855
|
+
|
1856
|
+
if m == "mxfl":
|
1857
|
+
return mixture_stats().fl
|
1858
|
+
|
1859
|
+
if m == "mxpkc":
|
1860
|
+
return mixture_stats().pkc
|
1861
|
+
|
1806
1862
|
if m == "mxtdco":
|
1807
1863
|
return target_stats().dco
|
1808
1864
|
|
@@ -1916,21 +1972,34 @@ class MixtureDatabase:
|
|
1916
1972
|
|
1917
1973
|
raise AttributeError(f"Unrecognized metric: '{m}'")
|
1918
1974
|
|
1919
|
-
result:
|
1975
|
+
result: dict[str, Any] = {}
|
1920
1976
|
for metric in metrics:
|
1921
|
-
result
|
1977
|
+
result[metric] = calc(metric)
|
1978
|
+
|
1979
|
+
# Check for metrics dependencies and add them even if not explicitly requested.
|
1980
|
+
if metric.startswith("mxwer"):
|
1981
|
+
dependencies = ("mxasr." + metric[6:], "tasr." + metric[6:])
|
1982
|
+
for dependency in dependencies:
|
1983
|
+
result[dependency] = calc(dependency)
|
1922
1984
|
|
1923
1985
|
return result
|
1924
1986
|
|
1925
1987
|
|
1926
|
-
|
1927
|
-
def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1988
|
+
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
|
1928
1989
|
"""Get spectral mask with ID from db
|
1929
1990
|
|
1930
1991
|
:param db: Database context
|
1931
1992
|
:param sm_id: Spectral mask ID
|
1993
|
+
:param use_cache: If true, use LRU caching
|
1932
1994
|
:return: Spectral mask
|
1933
1995
|
"""
|
1996
|
+
if use_cache:
|
1997
|
+
return __spectral_mask(db, sm_id)
|
1998
|
+
return __spectral_mask.__wrapped__(db, sm_id)
|
1999
|
+
|
2000
|
+
|
2001
|
+
@lru_cache
|
2002
|
+
def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1934
2003
|
from .db_datatypes import SpectralMaskRecord
|
1935
2004
|
|
1936
2005
|
with db() as c:
|
@@ -1953,12 +2022,26 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1953
2022
|
)
|
1954
2023
|
|
1955
2024
|
|
2025
|
+
def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
2026
|
+
"""Get target file with ID from db
|
2027
|
+
|
2028
|
+
:param db: Database context
|
2029
|
+
:param t_id: Target file ID
|
2030
|
+
:param use_cache: If true, use LRU caching
|
2031
|
+
:return: Target file
|
2032
|
+
"""
|
2033
|
+
if use_cache:
|
2034
|
+
return __target_file(db, t_id, use_cache)
|
2035
|
+
return __target_file.__wrapped__(db, t_id, use_cache)
|
2036
|
+
|
2037
|
+
|
1956
2038
|
@lru_cache
|
1957
|
-
def _target_file(db: partial, t_id: int) -> TargetFile:
|
2039
|
+
def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
1958
2040
|
"""Get target file with ID from db
|
1959
2041
|
|
1960
2042
|
:param db: Database context
|
1961
2043
|
:param t_id: Target file ID
|
2044
|
+
:param use_cache: If true, use LRU caching
|
1962
2045
|
:return: Target file
|
1963
2046
|
"""
|
1964
2047
|
import json
|
@@ -1982,19 +2065,26 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1982
2065
|
samples=target_file.samples,
|
1983
2066
|
class_indices=json.loads(target_file.class_indices),
|
1984
2067
|
level_type=target_file.level_type,
|
1985
|
-
truth_configs=_target_truth_configs(db, t_id),
|
2068
|
+
truth_configs=_target_truth_configs(db, t_id, use_cache),
|
1986
2069
|
speaker_id=target_file.speaker_id,
|
1987
2070
|
)
|
1988
2071
|
|
1989
2072
|
|
1990
|
-
|
1991
|
-
def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
2073
|
+
def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
|
1992
2074
|
"""Get noise file with ID from db
|
1993
2075
|
|
1994
2076
|
:param db: Database context
|
1995
2077
|
:param n_id: Noise file ID
|
2078
|
+
:param use_cache: If true, use LRU caching
|
1996
2079
|
:return: Noise file
|
1997
2080
|
"""
|
2081
|
+
if use_cache:
|
2082
|
+
return __noise_file(db, n_id)
|
2083
|
+
return __noise_file.__wrapped__(db, n_id)
|
2084
|
+
|
2085
|
+
|
2086
|
+
@lru_cache
|
2087
|
+
def __noise_file(db: partial, n_id: int) -> NoiseFile:
|
1998
2088
|
with db() as c:
|
1999
2089
|
noise = c.execute(
|
2000
2090
|
"""
|
@@ -2007,14 +2097,21 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
2007
2097
|
return NoiseFile(name=noise[0], samples=noise[1])
|
2008
2098
|
|
2009
2099
|
|
2010
|
-
|
2011
|
-
|
2012
|
-
"""Get impulse response file with ID from db
|
2100
|
+
def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
|
2101
|
+
"""Get impulse response file name with ID from db
|
2013
2102
|
|
2014
2103
|
:param db: Database context
|
2015
2104
|
:param ir_id: Impulse response file ID
|
2016
|
-
:
|
2105
|
+
:param use_cache: If true, use LRU caching
|
2106
|
+
:return: Impulse response file name
|
2017
2107
|
"""
|
2108
|
+
if use_cache:
|
2109
|
+
return __impulse_response_file(db, ir_id)
|
2110
|
+
return __impulse_response_file.__wrapped__(db, ir_id)
|
2111
|
+
|
2112
|
+
|
2113
|
+
@lru_cache
|
2114
|
+
def __impulse_response_file(db: partial, ir_id: int) -> str:
|
2018
2115
|
with db() as c:
|
2019
2116
|
return str(
|
2020
2117
|
c.execute(
|
@@ -2028,14 +2125,49 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
2028
2125
|
)
|
2029
2126
|
|
2030
2127
|
|
2128
|
+
def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
|
2129
|
+
"""Get impulse response delay with ID from db
|
2130
|
+
|
2131
|
+
:param db: Database context
|
2132
|
+
:param ir_id: Impulse response file ID
|
2133
|
+
:param use_cache: If true, use LRU caching
|
2134
|
+
:return: Impulse response delay
|
2135
|
+
"""
|
2136
|
+
if use_cache:
|
2137
|
+
return __impulse_response_delay(db, ir_id)
|
2138
|
+
return __impulse_response_delay.__wrapped__(db, ir_id)
|
2139
|
+
|
2140
|
+
|
2031
2141
|
@lru_cache
|
2032
|
-
def _mixture(db: partial, m_id: int) -> Mixture:
|
2142
|
+
def __impulse_response_delay(db: partial, ir_id: int) -> int:
|
2143
|
+
with db() as c:
|
2144
|
+
return int(
|
2145
|
+
c.execute(
|
2146
|
+
"""
|
2147
|
+
SELECT impulse_response_file.delay
|
2148
|
+
FROM impulse_response_file
|
2149
|
+
WHERE ? = impulse_response_file.id
|
2150
|
+
""",
|
2151
|
+
(ir_id + 1,),
|
2152
|
+
).fetchone()[0]
|
2153
|
+
)
|
2154
|
+
|
2155
|
+
|
2156
|
+
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
2033
2157
|
"""Get mixture record with ID from db
|
2034
2158
|
|
2035
2159
|
:param db: Database context
|
2036
2160
|
:param m_id: Zero-based mixture ID
|
2161
|
+
:param use_cache: If true, use LRU caching
|
2037
2162
|
:return: Mixture record
|
2038
2163
|
"""
|
2164
|
+
if use_cache:
|
2165
|
+
return __mixture(db, m_id)
|
2166
|
+
return __mixture.__wrapped__(db, m_id)
|
2167
|
+
|
2168
|
+
|
2169
|
+
@lru_cache
|
2170
|
+
def __mixture(db: partial, m_id: int) -> Mixture:
|
2039
2171
|
from .db_datatypes import MixtureRecord
|
2040
2172
|
from .db_datatypes import TargetRecord
|
2041
2173
|
from .helpers import to_mixture
|
@@ -2068,8 +2200,14 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
2068
2200
|
return to_mixture(mixture, targets)
|
2069
2201
|
|
2070
2202
|
|
2203
|
+
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
|
2204
|
+
if use_cache:
|
2205
|
+
return __speaker(db, s_id, tier)
|
2206
|
+
return __speaker.__wrapped__(db, s_id, tier)
|
2207
|
+
|
2208
|
+
|
2071
2209
|
@lru_cache
|
2072
|
-
def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
2210
|
+
def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
2073
2211
|
if s_id is None:
|
2074
2212
|
return None
|
2075
2213
|
|
@@ -2082,8 +2220,14 @@ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
|
2082
2220
|
return data[0]
|
2083
2221
|
|
2084
2222
|
|
2223
|
+
def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
|
2224
|
+
if use_cache:
|
2225
|
+
return __target_truth_configs(db, t_id)
|
2226
|
+
return __target_truth_configs.__wrapped__(db, t_id)
|
2227
|
+
|
2228
|
+
|
2085
2229
|
@lru_cache
|
2086
|
-
def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
2230
|
+
def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
2087
2231
|
import json
|
2088
2232
|
|
2089
2233
|
from .datatypes import TruthConfig
|
sonusai/mixture/targets.py
CHANGED
@@ -16,14 +16,11 @@ def get_augmented_targets(
|
|
16
16
|
|
17
17
|
augmented_targets: list[AugmentedTarget] = []
|
18
18
|
for mixup in mixups:
|
19
|
-
|
19
|
+
target_augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
|
20
20
|
for target_index in range(len(target_files)):
|
21
|
-
for
|
21
|
+
for target_augmentation_index in target_augmentation_indices:
|
22
22
|
augmented_targets.append(
|
23
|
-
AugmentedTarget(
|
24
|
-
target_id=target_index,
|
25
|
-
target_augmentation_id=augmentation_index,
|
26
|
-
)
|
23
|
+
AugmentedTarget(target_id=target_index, target_augmentation_id=target_augmentation_index)
|
27
24
|
)
|
28
25
|
|
29
26
|
return augmented_targets
|
@@ -13,6 +13,7 @@ def _core(
|
|
13
13
|
parameters: int,
|
14
14
|
mapped: bool,
|
15
15
|
snr: bool,
|
16
|
+
use_cache: bool = True,
|
16
17
|
) -> Truth:
|
17
18
|
from os.path import join
|
18
19
|
|
@@ -50,8 +51,8 @@ def _core(
|
|
50
51
|
tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
51
52
|
|
52
53
|
if mapped:
|
53
|
-
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
|
54
|
-
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
|
54
|
+
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]), use_cache)
|
55
|
+
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]), use_cache)
|
55
56
|
tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
|
56
57
|
|
57
58
|
truth[frame] = tmp
|
@@ -85,7 +86,7 @@ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
85
86
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
86
87
|
|
87
88
|
|
88
|
-
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
89
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
89
90
|
"""Frequency domain energy truth generation function
|
90
91
|
|
91
92
|
Calculates the true energy per bin:
|
@@ -104,6 +105,7 @@ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
|
|
104
105
|
parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
105
106
|
mapped=False,
|
106
107
|
snr=False,
|
108
|
+
use_cache=use_cache,
|
107
109
|
)
|
108
110
|
|
109
111
|
|
@@ -118,7 +120,7 @@ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
118
120
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
119
121
|
|
120
122
|
|
121
|
-
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
123
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
122
124
|
"""Frequency domain SNR truth function documentation
|
123
125
|
|
124
126
|
Calculates the true SNR per bin:
|
@@ -137,6 +139,7 @@ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) ->
|
|
137
139
|
parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
138
140
|
mapped=False,
|
139
141
|
snr=True,
|
142
|
+
use_cache=use_cache,
|
140
143
|
)
|
141
144
|
|
142
145
|
|
@@ -156,7 +159,7 @@ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> i
|
|
156
159
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
157
160
|
|
158
161
|
|
159
|
-
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
162
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
160
163
|
"""Frequency domain mapped SNR truth function documentation
|
161
164
|
|
162
165
|
Output shape: [:, bins]
|
@@ -169,6 +172,7 @@ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: d
|
|
169
172
|
parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
170
173
|
mapped=True,
|
171
174
|
snr=True,
|
175
|
+
use_cache=use_cache,
|
172
176
|
)
|
173
177
|
|
174
178
|
|
sonusai/mkwav.py
CHANGED
@@ -63,7 +63,7 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
|
|
63
63
|
if write_noise:
|
64
64
|
write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
|
65
65
|
|
66
|
-
write_mixture_metadata(mixdb, m_id)
|
66
|
+
write_mixture_metadata(mixdb, m_id=m_id)
|
67
67
|
|
68
68
|
|
69
69
|
def main() -> None:
|
sonusai/onnx_predict.py
CHANGED
@@ -193,7 +193,7 @@ def main() -> None:
|
|
193
193
|
# run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
|
194
194
|
predict = session.run(out_names, {in0name: feature})[0]
|
195
195
|
# predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
|
196
|
-
output_fname = join(output_dir, mixdb.
|
196
|
+
output_fname = join(output_dir, mixdb.mixture(mixid).name)
|
197
197
|
with h5py.File(output_fname, "a") as f:
|
198
198
|
if "predict" in f:
|
199
199
|
del f["predict"]
|
sonusai/queries/queries.py
CHANGED
@@ -178,7 +178,7 @@ def get_mixids_from_snr(
|
|
178
178
|
result: dict[float, list[int]] = {}
|
179
179
|
for snr in snrs:
|
180
180
|
# Get a list of mixids for each SNR
|
181
|
-
result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
|
181
|
+
result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures()) if mixture.snr == snr and i in mixid_out])
|
182
182
|
|
183
183
|
return result
|
184
184
|
|
sonusai/utils/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
# SonusAI general utilities
|
2
2
|
# ruff: noqa: F401
|
3
|
+
|
3
4
|
from .asl_p56 import asl_p56
|
4
5
|
from .asr import ASRResult
|
5
6
|
from .asr import calc_asr
|
@@ -53,5 +54,6 @@ from .stacked_complex import stacked_complex_imag
|
|
53
54
|
from .stacked_complex import stacked_complex_real
|
54
55
|
from .stacked_complex import unstack_complex
|
55
56
|
from .stratified_shuffle_split import stratified_shuffle_split_mixid
|
57
|
+
from .temp_seed import temp_seed
|
56
58
|
from .write_audio import write_audio
|
57
59
|
from .yes_or_no import yes_or_no
|
sonusai/utils/asr.py
CHANGED
@@ -65,7 +65,7 @@ def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
|
|
65
65
|
from sonusai.mixture import read_audio
|
66
66
|
|
67
67
|
if not isinstance(audio, np.ndarray):
|
68
|
-
audio = copy(read_audio(audio))
|
68
|
+
audio = copy(read_audio(audio, config.get("use_cache", True)))
|
69
69
|
|
70
70
|
return _asr_fn(engine)(audio, **config)
|
71
71
|
|
sonusai/utils/load_object.py
CHANGED
@@ -2,9 +2,15 @@ from functools import lru_cache
|
|
2
2
|
from typing import Any
|
3
3
|
|
4
4
|
|
5
|
+
def load_object(name: str, use_cache: bool = True) -> Any:
|
6
|
+
"""Load an object from a pickle file"""
|
7
|
+
if use_cache:
|
8
|
+
return _load_object(name)
|
9
|
+
return _load_object.__wrapped__(name)
|
10
|
+
|
11
|
+
|
5
12
|
@lru_cache
|
6
|
-
def load_object(name: str) -> Any:
|
7
|
-
"""Load an object from a pickle file (with LRU caching)"""
|
13
|
+
def _load_object(name: str) -> Any:
|
8
14
|
import pickle
|
9
15
|
from os.path import exists
|
10
16
|
|
@@ -42,7 +42,7 @@ def stratified_shuffle_split_mixid(
|
|
42
42
|
raise ValueError("vsplit must be between 0 and 1")
|
43
43
|
|
44
44
|
a_class_mixid: dict[int, list[int]] = {i + 1: [] for i in range(mixdb.num_classes)}
|
45
|
-
for mixid, mixture in enumerate(mixdb.mixtures):
|
45
|
+
for mixid, mixture in enumerate(mixdb.mixtures()):
|
46
46
|
class_count = get_class_count_from_mixids(mixdb, mixid)
|
47
47
|
if any(class_count):
|
48
48
|
for class_index in mixdb.target_files[mixture.targets[0].file_id].class_indices:
|