sonusai 0.19.8__py3-none-any.whl → 0.19.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +32 -54
- sonusai/metrics_summary.py +320 -0
- sonusai/mixture/__init__.py +2 -1
- sonusai/mixture/audio.py +40 -7
- sonusai/mixture/generation.py +100 -121
- sonusai/mixture/helpers.py +22 -7
- sonusai/mixture/mixdb.py +90 -30
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/METADATA +1 -1
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/RECORD +27 -24
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/WHEEL +0 -0
- {sonusai-0.19.8.dist-info → sonusai-0.19.10.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py
CHANGED
@@ -61,7 +61,7 @@ def db_connection(
|
|
61
61
|
if not create and readonly:
|
62
62
|
name += "?mode=ro"
|
63
63
|
|
64
|
-
connection = sqlite3.connect("file:" + name, uri=True)
|
64
|
+
connection = sqlite3.connect("file:" + name, uri=True, timeout=20)
|
65
65
|
|
66
66
|
if verbose:
|
67
67
|
connection.set_trace_callback(print)
|
@@ -84,7 +84,7 @@ class SQLiteContextManager:
|
|
84
84
|
|
85
85
|
|
86
86
|
class MixtureDatabase:
|
87
|
-
def __init__(self, location: str, test: bool = False) -> None:
|
87
|
+
def __init__(self, location: str, test: bool = False, use_cache: bool = True) -> None:
|
88
88
|
import json
|
89
89
|
from os.path import exists
|
90
90
|
|
@@ -92,6 +92,7 @@ class MixtureDatabase:
|
|
92
92
|
|
93
93
|
self.location = location
|
94
94
|
self.test = test
|
95
|
+
self.use_cache = use_cache
|
95
96
|
|
96
97
|
if not exists(db_file(self.location, self.test)):
|
97
98
|
raise OSError(f"Could not find mixture database in {self.location}")
|
@@ -121,7 +122,7 @@ class MixtureDatabase:
|
|
121
122
|
class_weights_threshold=self.class_weights_thresholds,
|
122
123
|
feature=self.feature,
|
123
124
|
impulse_response_files=self.impulse_response_files,
|
124
|
-
mixtures=self.mixtures,
|
125
|
+
mixtures=self.mixtures(),
|
125
126
|
noise_mix_mode=self.noise_mix_mode,
|
126
127
|
noise_files=self.noise_files,
|
127
128
|
num_classes=self.num_classes,
|
@@ -488,7 +489,7 @@ class MixtureDatabase:
|
|
488
489
|
return truth_configs
|
489
490
|
|
490
491
|
def target_truth_configs(self, t_id: int) -> TruthConfigs:
|
491
|
-
return _target_truth_configs(self.db, t_id)
|
492
|
+
return _target_truth_configs(self.db, t_id, self.use_cache)
|
492
493
|
|
493
494
|
@cached_property
|
494
495
|
def random_snrs(self) -> list[float]:
|
@@ -556,7 +557,7 @@ class MixtureDatabase:
|
|
556
557
|
:param sm_id: Spectral mask ID
|
557
558
|
:return: Spectral mask
|
558
559
|
"""
|
559
|
-
return _spectral_mask(self.db, sm_id)
|
560
|
+
return _spectral_mask(self.db, sm_id, self.use_cache)
|
560
561
|
|
561
562
|
@cached_property
|
562
563
|
def target_files(self) -> list[TargetFile]:
|
@@ -619,7 +620,7 @@ class MixtureDatabase:
|
|
619
620
|
:param t_id: Target file ID
|
620
621
|
:return: Target file
|
621
622
|
"""
|
622
|
-
return _target_file(self.db, t_id)
|
623
|
+
return _target_file(self.db, t_id, self.use_cache)
|
623
624
|
|
624
625
|
@cached_property
|
625
626
|
def num_target_files(self) -> int:
|
@@ -657,7 +658,7 @@ class MixtureDatabase:
|
|
657
658
|
:param n_id: Noise file ID
|
658
659
|
:return: Noise file
|
659
660
|
"""
|
660
|
-
return _noise_file(self.db, n_id)
|
661
|
+
return _noise_file(self.db, n_id, self.use_cache)
|
661
662
|
|
662
663
|
@cached_property
|
663
664
|
def num_noise_files(self) -> int:
|
@@ -706,7 +707,7 @@ class MixtureDatabase:
|
|
706
707
|
"""
|
707
708
|
if ir_id is None:
|
708
709
|
return None
|
709
|
-
return _impulse_response_file(self.db, ir_id)
|
710
|
+
return _impulse_response_file(self.db, ir_id, self.use_cache)
|
710
711
|
|
711
712
|
@cached_property
|
712
713
|
def num_impulse_response_files(self) -> int:
|
@@ -717,7 +718,6 @@ class MixtureDatabase:
|
|
717
718
|
with self.db() as c:
|
718
719
|
return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
|
719
720
|
|
720
|
-
@cached_property
|
721
721
|
def mixtures(self) -> list[Mixture]:
|
722
722
|
"""Get mixtures from db
|
723
723
|
|
@@ -760,7 +760,7 @@ class MixtureDatabase:
|
|
760
760
|
:param m_id: Zero-based mixture ID
|
761
761
|
:return: Mixture record
|
762
762
|
"""
|
763
|
-
return _mixture(self.db, m_id)
|
763
|
+
return _mixture(self.db, m_id, self.use_cache)
|
764
764
|
|
765
765
|
@cached_property
|
766
766
|
def mixid_width(self) -> int:
|
@@ -805,7 +805,7 @@ class MixtureDatabase:
|
|
805
805
|
"""
|
806
806
|
from .audio import read_audio
|
807
807
|
|
808
|
-
return read_audio(self.target_file(t_id).name)
|
808
|
+
return read_audio(self.target_file(t_id).name, self.use_cache)
|
809
809
|
|
810
810
|
def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
|
811
811
|
"""Get augmented noise audio
|
@@ -819,12 +819,12 @@ class MixtureDatabase:
|
|
819
819
|
from .augmentation import apply_impulse_response
|
820
820
|
|
821
821
|
noise = self.noise_file(mixture.noise.file_id)
|
822
|
-
audio = read_audio(noise.name)
|
822
|
+
audio = read_audio(noise.name, self.use_cache)
|
823
823
|
audio = apply_augmentation(audio, mixture.noise.augmentation)
|
824
824
|
if mixture.noise.augmentation.ir is not None:
|
825
825
|
audio = apply_impulse_response(
|
826
826
|
audio,
|
827
|
-
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
|
827
|
+
read_ir(self.impulse_response_file(mixture.noise.augmentation.ir), self.use_cache), # pyright: ignore [reportArgumentType]
|
828
828
|
)
|
829
829
|
|
830
830
|
return audio
|
@@ -1332,7 +1332,7 @@ class MixtureDatabase:
|
|
1332
1332
|
return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
|
1333
1333
|
|
1334
1334
|
def speaker(self, s_id: int | None, tier: str) -> str | None:
|
1335
|
-
return _speaker(self.db, s_id, tier)
|
1335
|
+
return _speaker(self.db, s_id, tier, self.use_cache)
|
1336
1336
|
|
1337
1337
|
def speech_metadata(self, tier: str) -> list[str]:
|
1338
1338
|
from .helpers import get_textgrid_tier_from_target_file
|
@@ -1464,7 +1464,7 @@ class MixtureDatabase:
|
|
1464
1464
|
|
1465
1465
|
return sorted(result)
|
1466
1466
|
|
1467
|
-
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) ->
|
1467
|
+
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
|
1468
1468
|
"""Get metrics data for the given mixture ID
|
1469
1469
|
|
1470
1470
|
:param m_id: Zero-based mixture ID
|
@@ -1916,21 +1916,34 @@ class MixtureDatabase:
|
|
1916
1916
|
|
1917
1917
|
raise AttributeError(f"Unrecognized metric: '{m}'")
|
1918
1918
|
|
1919
|
-
result:
|
1919
|
+
result: dict[str, Any] = {}
|
1920
1920
|
for metric in metrics:
|
1921
|
-
result
|
1921
|
+
result[metric] = calc(metric)
|
1922
|
+
|
1923
|
+
# Check for metrics dependencies and add them even if not explicitly requested.
|
1924
|
+
if metric.startswith("mxwer"):
|
1925
|
+
dependencies = ("mxasr." + metric[6:], "tasr." + metric[6:])
|
1926
|
+
for dependency in dependencies:
|
1927
|
+
result[dependency] = calc(dependency)
|
1922
1928
|
|
1923
1929
|
return result
|
1924
1930
|
|
1925
1931
|
|
1926
|
-
|
1927
|
-
def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1932
|
+
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
|
1928
1933
|
"""Get spectral mask with ID from db
|
1929
1934
|
|
1930
1935
|
:param db: Database context
|
1931
1936
|
:param sm_id: Spectral mask ID
|
1937
|
+
:param use_cache: If true, use LRU caching
|
1932
1938
|
:return: Spectral mask
|
1933
1939
|
"""
|
1940
|
+
if use_cache:
|
1941
|
+
return __spectral_mask(db, sm_id)
|
1942
|
+
return __spectral_mask.__wrapped__(db, sm_id)
|
1943
|
+
|
1944
|
+
|
1945
|
+
@lru_cache
|
1946
|
+
def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
1934
1947
|
from .db_datatypes import SpectralMaskRecord
|
1935
1948
|
|
1936
1949
|
with db() as c:
|
@@ -1953,12 +1966,26 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
|
|
1953
1966
|
)
|
1954
1967
|
|
1955
1968
|
|
1969
|
+
def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
1970
|
+
"""Get target file with ID from db
|
1971
|
+
|
1972
|
+
:param db: Database context
|
1973
|
+
:param t_id: Target file ID
|
1974
|
+
:param use_cache: If true, use LRU caching
|
1975
|
+
:return: Target file
|
1976
|
+
"""
|
1977
|
+
if use_cache:
|
1978
|
+
return __target_file(db, t_id, use_cache)
|
1979
|
+
return __target_file.__wrapped__(db, t_id, use_cache)
|
1980
|
+
|
1981
|
+
|
1956
1982
|
@lru_cache
|
1957
|
-
def
|
1983
|
+
def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
|
1958
1984
|
"""Get target file with ID from db
|
1959
1985
|
|
1960
1986
|
:param db: Database context
|
1961
1987
|
:param t_id: Target file ID
|
1988
|
+
:param use_cache: If true, use LRU caching
|
1962
1989
|
:return: Target file
|
1963
1990
|
"""
|
1964
1991
|
import json
|
@@ -1982,19 +2009,26 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
|
|
1982
2009
|
samples=target_file.samples,
|
1983
2010
|
class_indices=json.loads(target_file.class_indices),
|
1984
2011
|
level_type=target_file.level_type,
|
1985
|
-
truth_configs=_target_truth_configs(db, t_id),
|
2012
|
+
truth_configs=_target_truth_configs(db, t_id, use_cache),
|
1986
2013
|
speaker_id=target_file.speaker_id,
|
1987
2014
|
)
|
1988
2015
|
|
1989
2016
|
|
1990
|
-
|
1991
|
-
def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
2017
|
+
def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
|
1992
2018
|
"""Get noise file with ID from db
|
1993
2019
|
|
1994
2020
|
:param db: Database context
|
1995
2021
|
:param n_id: Noise file ID
|
2022
|
+
:param use_cache: If true, use LRU caching
|
1996
2023
|
:return: Noise file
|
1997
2024
|
"""
|
2025
|
+
if use_cache:
|
2026
|
+
return __noise_file(db, n_id)
|
2027
|
+
return __noise_file.__wrapped__(db, n_id)
|
2028
|
+
|
2029
|
+
|
2030
|
+
@lru_cache
|
2031
|
+
def __noise_file(db: partial, n_id: int) -> NoiseFile:
|
1998
2032
|
with db() as c:
|
1999
2033
|
noise = c.execute(
|
2000
2034
|
"""
|
@@ -2007,14 +2041,21 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
|
|
2007
2041
|
return NoiseFile(name=noise[0], samples=noise[1])
|
2008
2042
|
|
2009
2043
|
|
2010
|
-
|
2011
|
-
def _impulse_response_file(db: partial, ir_id: int) -> str:
|
2044
|
+
def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
|
2012
2045
|
"""Get impulse response file with ID from db
|
2013
2046
|
|
2014
2047
|
:param db: Database context
|
2015
2048
|
:param ir_id: Impulse response file ID
|
2016
|
-
:
|
2049
|
+
:param use_cache: If true, use LRU caching
|
2050
|
+
:return: Impulse response
|
2017
2051
|
"""
|
2052
|
+
if use_cache:
|
2053
|
+
return __impulse_response_file(db, ir_id)
|
2054
|
+
return __impulse_response_file.__wrapped__(db, ir_id)
|
2055
|
+
|
2056
|
+
|
2057
|
+
@lru_cache
|
2058
|
+
def __impulse_response_file(db: partial, ir_id: int) -> str:
|
2018
2059
|
with db() as c:
|
2019
2060
|
return str(
|
2020
2061
|
c.execute(
|
@@ -2028,14 +2069,21 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
|
|
2028
2069
|
)
|
2029
2070
|
|
2030
2071
|
|
2031
|
-
|
2032
|
-
def _mixture(db: partial, m_id: int) -> Mixture:
|
2072
|
+
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
|
2033
2073
|
"""Get mixture record with ID from db
|
2034
2074
|
|
2035
2075
|
:param db: Database context
|
2036
2076
|
:param m_id: Zero-based mixture ID
|
2077
|
+
:param use_cache: If true, use LRU caching
|
2037
2078
|
:return: Mixture record
|
2038
2079
|
"""
|
2080
|
+
if use_cache:
|
2081
|
+
return __mixture(db, m_id)
|
2082
|
+
return __mixture.__wrapped__(db, m_id)
|
2083
|
+
|
2084
|
+
|
2085
|
+
@lru_cache
|
2086
|
+
def __mixture(db: partial, m_id: int) -> Mixture:
|
2039
2087
|
from .db_datatypes import MixtureRecord
|
2040
2088
|
from .db_datatypes import TargetRecord
|
2041
2089
|
from .helpers import to_mixture
|
@@ -2068,8 +2116,14 @@ def _mixture(db: partial, m_id: int) -> Mixture:
|
|
2068
2116
|
return to_mixture(mixture, targets)
|
2069
2117
|
|
2070
2118
|
|
2119
|
+
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
|
2120
|
+
if use_cache:
|
2121
|
+
return __speaker(db, s_id, tier)
|
2122
|
+
return __speaker.__wrapped__(db, s_id, tier)
|
2123
|
+
|
2124
|
+
|
2071
2125
|
@lru_cache
|
2072
|
-
def
|
2126
|
+
def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
2073
2127
|
if s_id is None:
|
2074
2128
|
return None
|
2075
2129
|
|
@@ -2082,8 +2136,14 @@ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
|
|
2082
2136
|
return data[0]
|
2083
2137
|
|
2084
2138
|
|
2139
|
+
def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
|
2140
|
+
if use_cache:
|
2141
|
+
return __target_truth_configs(db, t_id)
|
2142
|
+
return __target_truth_configs.__wrapped__(db, t_id)
|
2143
|
+
|
2144
|
+
|
2085
2145
|
@lru_cache
|
2086
|
-
def
|
2146
|
+
def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
|
2087
2147
|
import json
|
2088
2148
|
|
2089
2149
|
from .datatypes import TruthConfig
|
@@ -4,10 +4,16 @@ from sonusai.mixture.datatypes import AudioT
|
|
4
4
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
5
5
|
|
6
6
|
|
7
|
-
def read_impulse_response(
|
7
|
+
def read_impulse_response(
|
8
|
+
name: str | Path,
|
9
|
+
delay_compensation: bool = True,
|
10
|
+
normalize: bool = True,
|
11
|
+
) -> ImpulseResponseData:
|
8
12
|
"""Read impulse response data using torchaudio
|
9
13
|
|
10
14
|
:param name: File name
|
15
|
+
:param delay_compensation: Apply delay compensation
|
16
|
+
:param normalize: Apply normalization
|
11
17
|
:return: ImpulseResponseData object
|
12
18
|
"""
|
13
19
|
import numpy as np
|
@@ -28,14 +34,19 @@ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
|
|
28
34
|
raise OSError(f"Error reading {name}: {e}") from e
|
29
35
|
|
30
36
|
raw = torch.squeeze(raw[0, :])
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
# raw = raw / torch.linalg.vector_norm(raw)
|
37
|
+
|
38
|
+
if delay_compensation:
|
39
|
+
offset = torch.argmax(raw)
|
40
|
+
raw = raw[offset:]
|
36
41
|
|
37
42
|
data = np.array(raw).astype(np.float32)
|
38
|
-
|
43
|
+
|
44
|
+
if normalize:
|
45
|
+
# Inexplicably,
|
46
|
+
# data = data / torch.linalg.vector_norm(data)
|
47
|
+
# causes multiprocessing contexts to hang.
|
48
|
+
# Use np.linalg.norm() instead.
|
49
|
+
data = data / np.linalg.norm(data)
|
39
50
|
|
40
51
|
return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
|
41
52
|
|
@@ -20,10 +20,9 @@ def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length:
|
|
20
20
|
|
21
21
|
effects: list[list[str]] = []
|
22
22
|
|
23
|
-
# TODO
|
24
|
-
#
|
25
|
-
#
|
26
|
-
# or hard-coded into the script?)
|
23
|
+
# TODO: Always normalize and remove normalize from list of available augmentations
|
24
|
+
# Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
|
25
|
+
# TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
|
27
26
|
if augmentation.normalize is not None:
|
28
27
|
effects.append(["norm", str(augmentation.normalize)])
|
29
28
|
|
@@ -13,6 +13,7 @@ def _core(
|
|
13
13
|
parameters: int,
|
14
14
|
mapped: bool,
|
15
15
|
snr: bool,
|
16
|
+
use_cache: bool = True,
|
16
17
|
) -> Truth:
|
17
18
|
from os.path import join
|
18
19
|
|
@@ -50,8 +51,8 @@ def _core(
|
|
50
51
|
tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
51
52
|
|
52
53
|
if mapped:
|
53
|
-
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
|
54
|
-
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
|
54
|
+
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]), use_cache)
|
55
|
+
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]), use_cache)
|
55
56
|
tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
|
56
57
|
|
57
58
|
truth[frame] = tmp
|
@@ -85,7 +86,7 @@ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
85
86
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
86
87
|
|
87
88
|
|
88
|
-
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
89
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
89
90
|
"""Frequency domain energy truth generation function
|
90
91
|
|
91
92
|
Calculates the true energy per bin:
|
@@ -104,6 +105,7 @@ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
|
|
104
105
|
parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
105
106
|
mapped=False,
|
106
107
|
snr=False,
|
108
|
+
use_cache=use_cache,
|
107
109
|
)
|
108
110
|
|
109
111
|
|
@@ -118,7 +120,7 @@ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
118
120
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
119
121
|
|
120
122
|
|
121
|
-
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
123
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
122
124
|
"""Frequency domain SNR truth function documentation
|
123
125
|
|
124
126
|
Calculates the true SNR per bin:
|
@@ -137,6 +139,7 @@ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) ->
|
|
137
139
|
parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
138
140
|
mapped=False,
|
139
141
|
snr=True,
|
142
|
+
use_cache=use_cache,
|
140
143
|
)
|
141
144
|
|
142
145
|
|
@@ -156,7 +159,7 @@ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> i
|
|
156
159
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
157
160
|
|
158
161
|
|
159
|
-
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
162
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
|
160
163
|
"""Frequency domain mapped SNR truth function documentation
|
161
164
|
|
162
165
|
Output shape: [:, bins]
|
@@ -169,6 +172,7 @@ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: d
|
|
169
172
|
parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
170
173
|
mapped=True,
|
171
174
|
snr=True,
|
175
|
+
use_cache=use_cache,
|
172
176
|
)
|
173
177
|
|
174
178
|
|
sonusai/mkwav.py
CHANGED
@@ -63,7 +63,7 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
|
|
63
63
|
if write_noise:
|
64
64
|
write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
|
65
65
|
|
66
|
-
write_mixture_metadata(mixdb, m_id)
|
66
|
+
write_mixture_metadata(mixdb, m_id=m_id)
|
67
67
|
|
68
68
|
|
69
69
|
def main() -> None:
|
sonusai/onnx_predict.py
CHANGED
@@ -193,7 +193,7 @@ def main() -> None:
|
|
193
193
|
# run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
|
194
194
|
predict = session.run(out_names, {in0name: feature})[0]
|
195
195
|
# predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
|
196
|
-
output_fname = join(output_dir, mixdb.
|
196
|
+
output_fname = join(output_dir, mixdb.mixture(mixid).name)
|
197
197
|
with h5py.File(output_fname, "a") as f:
|
198
198
|
if "predict" in f:
|
199
199
|
del f["predict"]
|
sonusai/queries/queries.py
CHANGED
@@ -178,7 +178,7 @@ def get_mixids_from_snr(
|
|
178
178
|
result: dict[float, list[int]] = {}
|
179
179
|
for snr in snrs:
|
180
180
|
# Get a list of mixids for each SNR
|
181
|
-
result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures) if mixture.snr == snr and i in mixid_out])
|
181
|
+
result[snr] = sorted([i for i, mixture in enumerate(mixdb.mixtures()) if mixture.snr == snr and i in mixid_out])
|
182
182
|
|
183
183
|
return result
|
184
184
|
|
sonusai/utils/asr.py
CHANGED
@@ -65,7 +65,7 @@ def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
|
|
65
65
|
from sonusai.mixture import read_audio
|
66
66
|
|
67
67
|
if not isinstance(audio, np.ndarray):
|
68
|
-
audio = copy(read_audio(audio))
|
68
|
+
audio = copy(read_audio(audio, config.get("use_cache", True)))
|
69
69
|
|
70
70
|
return _asr_fn(engine)(audio, **config)
|
71
71
|
|
sonusai/utils/load_object.py
CHANGED
@@ -2,9 +2,15 @@ from functools import lru_cache
|
|
2
2
|
from typing import Any
|
3
3
|
|
4
4
|
|
5
|
+
def load_object(name: str, use_cache: bool = True) -> Any:
|
6
|
+
"""Load an object from a pickle file"""
|
7
|
+
if use_cache:
|
8
|
+
return _load_object(name)
|
9
|
+
return _load_object.__wrapped__(name)
|
10
|
+
|
11
|
+
|
5
12
|
@lru_cache
|
6
|
-
def
|
7
|
-
"""Load an object from a pickle file (with LRU caching)"""
|
13
|
+
def _load_object(name: str) -> Any:
|
8
14
|
import pickle
|
9
15
|
from os.path import exists
|
10
16
|
|
@@ -42,7 +42,7 @@ def stratified_shuffle_split_mixid(
|
|
42
42
|
raise ValueError("vsplit must be between 0 and 1")
|
43
43
|
|
44
44
|
a_class_mixid: dict[int, list[int]] = {i + 1: [] for i in range(mixdb.num_classes)}
|
45
|
-
for mixid, mixture in enumerate(mixdb.mixtures):
|
45
|
+
for mixid, mixture in enumerate(mixdb.mixtures()):
|
46
46
|
class_count = get_class_count_from_mixids(mixdb, mixid)
|
47
47
|
if any(class_count):
|
48
48
|
for class_index in mixdb.target_files[mixture.targets[0].file_id].class_indices:
|
@@ -1,9 +1,11 @@
|
|
1
1
|
sonusai/__init__.py,sha256=NSb0bvmAh6Rm2MDtchpAGsg8a3BrmVnShYb-vC_emH8,2802
|
2
2
|
sonusai/aawscd_probwrite.py,sha256=QZLMQrmPr3OjZ06buyYDwlnk9YPCpyr4KHkBjPsiqjU,3700
|
3
3
|
sonusai/audiofe.py,sha256=iFdthh4UrOvziT8urjrjD7dACWZPQz9orM5bVAW3WSQ,11269
|
4
|
-
sonusai/calc_metric_spenh.py,sha256=
|
4
|
+
sonusai/calc_metric_spenh.py,sha256=XWa2DzLSCEQ6GzsJv-YHfnN51f_oFwcRMMgMzusAvYA,49304
|
5
5
|
sonusai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
6
|
sonusai/data/genmixdb.yml,sha256=U_kLbE7gZ5rA7yNSB2NW7eK5dnYP5grJVMR321VMLt8,940
|
7
|
+
sonusai/data/silero_vad_v5.1.jit,sha256=hcSOHw7LYE5dKiaPPM-5EtT36TWs3IavWj_FsK6nspo,2269612
|
8
|
+
sonusai/data/silero_vad_v5.1.onnx,sha256=JiOilT9v89LB5hdAxs23FoEzR5smff7xFKSjzFvdeI8,2327524
|
7
9
|
sonusai/data/speech_ma01_01.wav,sha256=PK0vMKg-NR6rPE3KouxHGF6PKXnJCr7AwjMqfu98LUA,76644
|
8
10
|
sonusai/data/whitenoise.wav,sha256=I2umov0m34y56F9IsIBi1XtE76ZeZaSKDf70cJRe3pI,1920044
|
9
11
|
sonusai/deprecated/gentcst.py,sha256=nKbHy3aHreHqA-XnLQOzOApS8RuTNUFqnx52a8I5zLQ,19921
|
@@ -12,10 +14,10 @@ sonusai/deprecated/tplot.py,sha256=0p238DvTaP4oU9y-dp0JdLaTV4TKrooAwbx7zdz_QAc,1
|
|
12
14
|
sonusai/doc/__init__.py,sha256=KyQ26Um0RM8A3GYsb_tbFH64RwpoAw6lja2f_moUWas,33
|
13
15
|
sonusai/doc/doc.py,sha256=VZXauwbOb-VIufWw-lu0yfrd6jMRPeFeVPaaEjZNvn4,18881
|
14
16
|
sonusai/doc.py,sha256=zSmXpioB0YS_5-7kqfS5cr--veSaXkxRKzldId9Hyoc,878
|
15
|
-
sonusai/genft.py,sha256=
|
16
|
-
sonusai/genmetrics.py,sha256=
|
17
|
-
sonusai/genmix.py,sha256=
|
18
|
-
sonusai/genmixdb.py,sha256=
|
17
|
+
sonusai/genft.py,sha256=K2wjO5J48UgyhCj2Sx789nkjt0DWtYgnRDbQyNtjCSY,5591
|
18
|
+
sonusai/genmetrics.py,sha256=jORQCdf_SCrtcvDd47lgcPgQTplG956RTAqmf58Xe8Y,5689
|
19
|
+
sonusai/genmix.py,sha256=mSc5FfAYrUt3zloPSnp81dks8ntvSH6jyk-nh97wnww,6707
|
20
|
+
sonusai/genmixdb.py,sha256=SsbHRpPoJ77XzOBQRRDheucyuJzE-tucQtRoYl89ApU,17841
|
19
21
|
sonusai/lsdb.py,sha256=0HOGDDndB3LT9cz9AaxKIpt9vslAoSP4F239gply4Xg,5149
|
20
22
|
sonusai/main.py,sha256=HbnEia1B1-Z-mlHkLfojH8aj9GIpL1Btw3oH60T_CCQ,2590
|
21
23
|
sonusai/metrics/__init__.py,sha256=ssV6JEK_oklRSocsp6HMcG-GtJvV8IkRQtdKhHHmwU8,878
|
@@ -35,8 +37,9 @@ sonusai/metrics/class_summary.py,sha256=ZA7zNgwBpmTs1TP_t4jRT0pWnDnATC_up_8qE4aH
|
|
35
37
|
sonusai/metrics/confusion_matrix_summary.py,sha256=zBL_Ke7wF6oKtrKZPr0fsyF_taofdjxBlZmKodu0xUA,3143
|
36
38
|
sonusai/metrics/one_hot.py,sha256=hmuyh-9tpRjb_oyqU3WqZ14zItpRJQfcqBDKJeb5H9I,13930
|
37
39
|
sonusai/metrics/snr_summary.py,sha256=t8Fi_8WtboTi8flkZuOiHq9H3-nIELx4AKvnm-qvxLQ,5785
|
38
|
-
sonusai/
|
39
|
-
sonusai/mixture/
|
40
|
+
sonusai/metrics_summary.py,sha256=HVqjgCavxM1yzyoeDZSg_bJaXrifNQxNY7xYNKKva8g,12004
|
41
|
+
sonusai/mixture/__init__.py,sha256=ePkmFbBltwHsx1eJDb_RDieTceZtqa1wVY1D2Pfg2rw,5162
|
42
|
+
sonusai/mixture/audio.py,sha256=5iq39_Q0q9xuN_FNylvnn-gAZ8Io3Ir1Mqj60mVQeaQ,3432
|
40
43
|
sonusai/mixture/augmentation.py,sha256=s8QlPHnFJOblRU59fMQ-Zqysiv4OUJ7CxLRcV81lnaA,10407
|
41
44
|
sonusai/mixture/class_count.py,sha256=zcC3BDYMPN6wJYmO1RcOuqmrnTQIbMSznl33oN3e2sc,597
|
42
45
|
sonusai/mixture/config.py,sha256=g5ZmOhFYqmEdRQYSgfDIZ9VM0QiTwBqk7vIyAvxnPMo,24211
|
@@ -46,10 +49,10 @@ sonusai/mixture/datatypes.py,sha256=xNDBWFTVQ3plJ7qHKzrXyV4pffPYuf1xMVqBsR40n4o,
|
|
46
49
|
sonusai/mixture/db_datatypes.py,sha256=kvdUOMS6Pkkj9AmxCiq6zM8x7jbPPi933tVaXRxbTdQ,1534
|
47
50
|
sonusai/mixture/eq_rule_is_valid.py,sha256=O3gCAs_0hpxENK5b7kxxpDmOpKHlXGBWuLGT_97ARSM,1210
|
48
51
|
sonusai/mixture/feature.py,sha256=L0bPFG0RO-CrrtTStUMt_14euYsVo8_TWTP2IKSFKaA,2335
|
49
|
-
sonusai/mixture/generation.py,sha256=
|
50
|
-
sonusai/mixture/helpers.py,sha256=
|
52
|
+
sonusai/mixture/generation.py,sha256=yoJOcY9KPe_B1RVnENVr4ekcnXyZJMdvKMbJggpLOi4,38084
|
53
|
+
sonusai/mixture/helpers.py,sha256=Bt9njNb_OZ3j02qgrVEMZiL0hX4kXtFK_tkPoGoeb4Y,15787
|
51
54
|
sonusai/mixture/log_duration_and_sizes.py,sha256=qhgl87C2KbjxLdKEpjYOoqNL6rc-8-PB4R7Gx_7UG8g,1240
|
52
|
-
sonusai/mixture/mixdb.py,sha256=
|
55
|
+
sonusai/mixture/mixdb.py,sha256=Yg3FQqb6oI3LsFh_00CvMeH1Rrmn2pA5waaAyJDCpfY,75912
|
53
56
|
sonusai/mixture/soundfile_audio.py,sha256=At_ZC2b9pZ_9IYp1UxyPzRoBK9-1cKPCLMm74F1AjKE,4092
|
54
57
|
sonusai/mixture/sox_audio.py,sha256=7ouCLqXYS6tjG2L0v5lugVO7z5UwJmsr1VigbrXhs74,16725
|
55
58
|
sonusai/mixture/sox_augmentation.py,sha256=DtfGLPaB1BIt2wvTEA__MYkGFNU85Tuup5BFsIVrh0E,4546
|
@@ -57,22 +60,22 @@ sonusai/mixture/spectral_mask.py,sha256=U9XJ_SAoI9b67K_3SE7bNw6U8cPGFOBttaZAxMjA
|
|
57
60
|
sonusai/mixture/target_class_balancing.py,sha256=o_TZ8kVYq10lgeXHh3GUFfflfdUvRt0FekFu2eaNkDs,4251
|
58
61
|
sonusai/mixture/targets.py,sha256=6emo2fxxp9ZhSpHuUM9xIjYMz8zeIHAw684jT3l7fAs,6442
|
59
62
|
sonusai/mixture/tokenized_shell_vars.py,sha256=lXTzUDutuBWGV1zIsqeIxWmy-eKm0Vx1y8-iLdsL1gQ,4921
|
60
|
-
sonusai/mixture/torchaudio_audio.py,sha256=
|
61
|
-
sonusai/mixture/torchaudio_augmentation.py,sha256=
|
63
|
+
sonusai/mixture/torchaudio_audio.py,sha256=72Hxo5TKAW7mYpRy15QFfD7AYDORBk6bVCcHENniWGw,3116
|
64
|
+
sonusai/mixture/torchaudio_augmentation.py,sha256=uFAKxIfs50J5FR-WXodsEACm2Ao-t5dZRSJ0DwTAfBg,3930
|
62
65
|
sonusai/mixture/truth.py,sha256=-CwwawFRGjqodR2yKvAMGL1XaYLct-tli7wZ2gbhLtQ,2121
|
63
66
|
sonusai/mixture/truth_functions/__init__.py,sha256=0mlOFChPnXG5BC0eKOe4n9VH17jY4iOqZFLuF6Gprdk,1505
|
64
67
|
sonusai/mixture/truth_functions/crm.py,sha256=iidcffXfqV8k9O5wt5KTWIAFaTSjmhV5ucKZPbTgpvQ,3809
|
65
|
-
sonusai/mixture/truth_functions/energy.py,sha256=
|
68
|
+
sonusai/mixture/truth_functions/energy.py,sha256=BMpyFoFDRsKEv3ZxZAJPLgMgkBkA6AtGBg3MjRu1do8,6749
|
66
69
|
sonusai/mixture/truth_functions/file.py,sha256=pyCAhx3PhJRBoZMrjoQI4Tbi5TN7sPembSVEr80Bu3g,1431
|
67
70
|
sonusai/mixture/truth_functions/metadata.py,sha256=aEZly5bJEaZpUBZonWvcu14_Dn3M2HamwTaM5Bg7Tm8,778
|
68
|
-
sonusai/mixture/truth_functions/metrics.py,sha256=
|
71
|
+
sonusai/mixture/truth_functions/metrics.py,sha256=AzQjKJ7rihk_UXOz0Atyktpzo2g9ZPMZVZPCESBIoao,876
|
69
72
|
sonusai/mixture/truth_functions/phoneme.py,sha256=jwBYiNwwBwh2tHtOJ2NopYWhT6y19kXzSIag0XW9GSY,778
|
70
73
|
sonusai/mixture/truth_functions/sed.py,sha256=C0n9DkfBNQblFsFCkPbooy54KuHSY7B0f1vLft2asdw,3832
|
71
74
|
sonusai/mixture/truth_functions/target.py,sha256=nSkHFESzCEOljcYf4jQ7FmxsAWJtMCRRWFKM_DyjoLU,4926
|
72
|
-
sonusai/mkwav.py,sha256=
|
73
|
-
sonusai/onnx_predict.py,sha256=
|
75
|
+
sonusai/mkwav.py,sha256=ElivON2G_BT_ffKnePmPoeydl0g2DLGrbIFxfn_I1XI,4058
|
76
|
+
sonusai/onnx_predict.py,sha256=T97ceb9stR_QtJCA-Rmv67OIeaLdyhyCf1jQ9kVOYn8,8698
|
74
77
|
sonusai/queries/__init__.py,sha256=bhoeOFfu9GA5DOUuxRrIev7MYdXaGN8xdKJ6BXyNNtQ,277
|
75
|
-
sonusai/queries/queries.py,sha256=
|
78
|
+
sonusai/queries/queries.py,sha256=srcEYBqLJhjqyfuJ-FwNkUwpjxYiNQeybxL3eQGm2nw,7511
|
76
79
|
sonusai/speech/__init__.py,sha256=vqAymCBPjMUSM4OZKHTai6BYwXsOBlf_G_vOhELVf8I,133
|
77
80
|
sonusai/speech/l2arctic.py,sha256=VQNKuTbmlbW0PJ7bOjx9sr0VjUYxJnxfTiPJIa4OOaA,3829
|
78
81
|
sonusai/speech/librispeech.py,sha256=ugP3NVOenSsBF1cUG4Nyl7dumGHQmE4Ugk1yYjtOyj4,3070
|
@@ -85,7 +88,7 @@ sonusai/speech/voxceleb.py,sha256=Uu1kB1krf8hess1yuvGbYfV_VgYhklEyoz4I7KfrVpw,26
|
|
85
88
|
sonusai/summarize_metric_spenh.py,sha256=2w81ZgJahYvD6wCpE3DFoUFrXexLXjO44ITRVm1HJXw,1858
|
86
89
|
sonusai/utils/__init__.py,sha256=z72OlzZCHpYfYHKnHn7jznj6Zt7zB-FyO6hIgFk45As,2379
|
87
90
|
sonusai/utils/asl_p56.py,sha256=cPUVwXawF7vLJgs4zUtoRGk7Wdbe5KKti_-v_8xIU10,3862
|
88
|
-
sonusai/utils/asr.py,sha256=
|
91
|
+
sonusai/utils/asr.py,sha256=ubiU3E61HN3r9MhPV7ci37cnLZowll8KfjUS7os3Sho,2822
|
89
92
|
sonusai/utils/asr_functions/__init__.py,sha256=HKGRm_c48tcxlfwqH63m-MvhAoK_pCcw76lxmFmiP_U,63
|
90
93
|
sonusai/utils/asr_functions/aaware_whisper.py,sha256=M9Y8Pgh1oIrDOPZZPSRPDig8foxfgs3f8AsoZ8W00B0,2120
|
91
94
|
sonusai/utils/audio_devices.py,sha256=_Eiah86SZjbdp2baD2AUVF4FmhseiNuG3KJkd_LbULk,2041
|
@@ -104,7 +107,7 @@ sonusai/utils/get_frames_per_batch.py,sha256=xnq4tV7MT74N0H6b5ZsiAezqdXucboCLQw1
|
|
104
107
|
sonusai/utils/get_label_names.py,sha256=df4jZVaQ3WnYQqNj21iUV4aYWyQEZUNmgs93qKW-_rA,820
|
105
108
|
sonusai/utils/grouper.py,sha256=qyZ0nj84yOrC-RZsXHC-KJvcUliGktnV8S6-P3PD6_w,203
|
106
109
|
sonusai/utils/human_readable_size.py,sha256=DOCS7SAymrtTZli8AczvyCMCh44r7ZDgVBA7jSZupmA,356
|
107
|
-
sonusai/utils/load_object.py,sha256
|
110
|
+
sonusai/utils/load_object.py,sha256=if4Vammcd-jZTz_n7QzwNIlN4HqSL0v91I9YQzcvEEA,493
|
108
111
|
sonusai/utils/max_text_width.py,sha256=pxiJMwb_zlkNntexgo7S6lAuF7NLLZvFdOCkxdsQJVY,315
|
109
112
|
sonusai/utils/model_utils.py,sha256=OIJBhOjxR0wpxsd7A2r6J2AjqfdYgZzi6UEThw4S1lI,828
|
110
113
|
sonusai/utils/numeric_conversion.py,sha256=iFPXFU8C_1mW5tmDqHq8-xP1tL8nVaSmhQRakdCqy30,328
|
@@ -117,11 +120,11 @@ sonusai/utils/read_predict_data.py,sha256=PUSroxmWQGtr6_EcdSHmIFQoRGou8CKKqcggWy
|
|
117
120
|
sonusai/utils/reshape.py,sha256=Ozuh3UlmAS5NCeOK7NR8KgcQacHvgq10pys0VfCnOPU,5746
|
118
121
|
sonusai/utils/seconds_to_hms.py,sha256=9Ya9O97txFtTIXZUQw1K8g7b7Xx-ptvUtMUlzsIduTo,260
|
119
122
|
sonusai/utils/stacked_complex.py,sha256=JW6iAa1C-4Tuh4dD5c-D-O-yo-OY5Xm0AKVU0YsqsJU,2782
|
120
|
-
sonusai/utils/stratified_shuffle_split.py,sha256=
|
123
|
+
sonusai/utils/stratified_shuffle_split.py,sha256=fcGW8nkZIwUqq1qtxbK_ZH58sYULqZfv7iNBQnKGH-M,6706
|
121
124
|
sonusai/utils/write_audio.py,sha256=0lKdaX57N6H-UWdioqmXCJMjwT1eBz5B-bSGqDvloAc,838
|
122
125
|
sonusai/utils/yes_or_no.py,sha256=0h1okjXmDNbJp7rZJFR2V-HFU1GJDm3YFTUVmYExkOU,263
|
123
126
|
sonusai/vars.py,sha256=kBBzuvC8szmdIZEEDA7XXmD765addZKdM2aFipeGO1w,933
|
124
|
-
sonusai-0.19.
|
125
|
-
sonusai-0.19.
|
126
|
-
sonusai-0.19.
|
127
|
-
sonusai-0.19.
|
127
|
+
sonusai-0.19.10.dist-info/METADATA,sha256=ibwwklSb5-vmwAJMdRhW0MBWxqQYFVsYpEx5-8oaRXI,2536
|
128
|
+
sonusai-0.19.10.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
129
|
+
sonusai-0.19.10.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
130
|
+
sonusai-0.19.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|