sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/mixture/truth.py
CHANGED
@@ -1,50 +1,72 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AudioT
|
2
2
|
from sonusai.mixture.datatypes import Truth
|
3
|
-
from sonusai.mixture.datatypes import
|
3
|
+
from sonusai.mixture.datatypes import TruthConfig
|
4
4
|
from sonusai.mixture.mixdb import MixtureDatabase
|
5
5
|
|
6
6
|
|
7
|
-
def truth_function(
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
7
|
+
def truth_function(
|
8
|
+
target_audio: AudioT,
|
9
|
+
noise_audio: AudioT,
|
10
|
+
mixture_audio: AudioT,
|
11
|
+
config: TruthConfig,
|
12
|
+
feature: str,
|
13
|
+
num_classes: int,
|
14
|
+
class_indices: list[int],
|
15
|
+
target_gain: float,
|
16
|
+
) -> Truth:
|
12
17
|
from sonusai.mixture import truth_functions
|
13
|
-
from .truth_functions.data import Data
|
14
18
|
|
15
|
-
|
16
|
-
|
17
|
-
|
19
|
+
from .truth_functions.datatypes import TruthFunctionConfig
|
20
|
+
from .truth_functions.datatypes import TruthFunctionData
|
21
|
+
|
22
|
+
t_config = TruthFunctionConfig(
|
23
|
+
feature=feature,
|
24
|
+
num_classes=num_classes,
|
25
|
+
class_indices=class_indices,
|
26
|
+
target_gain=target_gain,
|
27
|
+
config=config.config,
|
28
|
+
)
|
29
|
+
t_data = TruthFunctionData(target_audio, noise_audio, mixture_audio)
|
18
30
|
|
19
31
|
try:
|
20
|
-
return getattr(truth_functions,
|
21
|
-
except AttributeError:
|
22
|
-
raise
|
32
|
+
return getattr(truth_functions, config.function)(t_data, t_config)
|
33
|
+
except AttributeError as e:
|
34
|
+
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
35
|
+
except Exception as e:
|
36
|
+
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
23
37
|
|
24
38
|
|
25
39
|
def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
|
26
40
|
"""Get a list of truth indices for a given mixid."""
|
27
|
-
from .targets import get_truth_indices_for_target
|
28
|
-
|
29
41
|
indices: list[int] = []
|
30
42
|
for target_id in [target.file_id for target in mixdb.mixture(mixid).targets]:
|
31
|
-
indices.append(*
|
43
|
+
indices.append(*mixdb.target_file(target_id).class_indices)
|
32
44
|
|
33
|
-
return sorted(
|
45
|
+
return sorted(set(indices))
|
34
46
|
|
35
47
|
|
36
|
-
def
|
48
|
+
def truth_stride_reduction(truth: Truth, function: str) -> Truth:
|
49
|
+
"""Reduce stride dimension of truth.
|
50
|
+
|
51
|
+
:param truth: Truth data [frames, stride, truth_parameters]
|
52
|
+
:param function: Truth stride reduction function name
|
53
|
+
:return: Stride reduced truth data [frames, stride or 1, truth_parameters]
|
54
|
+
"""
|
37
55
|
import numpy as np
|
38
56
|
|
39
|
-
|
57
|
+
if truth.ndim != 3:
|
58
|
+
raise ValueError("Invalid truth shape")
|
59
|
+
|
60
|
+
if function == "none":
|
61
|
+
return truth
|
40
62
|
|
41
|
-
if
|
42
|
-
return np.max(
|
63
|
+
if function == "max":
|
64
|
+
return np.max(truth, axis=1, keepdims=True)
|
43
65
|
|
44
|
-
if
|
45
|
-
return np.mean(
|
66
|
+
if function == "mean":
|
67
|
+
return np.mean(truth, axis=1, keepdims=True)
|
46
68
|
|
47
|
-
if
|
48
|
-
return
|
69
|
+
if function == "first":
|
70
|
+
return truth[:, 0, :].reshape((truth.shape[0], 1, truth.shape[2]))
|
49
71
|
|
50
|
-
raise
|
72
|
+
raise ValueError(f"Invalid truth stride reduction function: {function}")
|
@@ -1,13 +1,39 @@
|
|
1
1
|
# SonusAI truth functions
|
2
|
+
# ruff: noqa: F401
|
3
|
+
|
2
4
|
from .crm import crm
|
5
|
+
from .crm import crm_parameters
|
6
|
+
from .crm import crm_validate
|
3
7
|
from .crm import crmp
|
8
|
+
from .crm import crmp_parameters
|
9
|
+
from .crm import crmp_validate
|
4
10
|
from .energy import energy_f
|
11
|
+
from .energy import energy_f_parameters
|
12
|
+
from .energy import energy_f_validate
|
5
13
|
from .energy import energy_t
|
14
|
+
from .energy import energy_t_parameters
|
15
|
+
from .energy import energy_t_validate
|
6
16
|
from .energy import mapped_snr_f
|
17
|
+
from .energy import mapped_snr_f_parameters
|
18
|
+
from .energy import mapped_snr_f_validate
|
7
19
|
from .energy import snr_f
|
20
|
+
from .energy import snr_f_parameters
|
21
|
+
from .energy import snr_f_validate
|
8
22
|
from .file import file
|
23
|
+
from .file import file_parameters
|
24
|
+
from .file import file_validate
|
9
25
|
from .phoneme import phoneme
|
26
|
+
from .phoneme import phoneme_parameters
|
27
|
+
from .phoneme import phoneme_validate
|
10
28
|
from .sed import sed
|
29
|
+
from .sed import sed_parameters
|
30
|
+
from .sed import sed_validate
|
11
31
|
from .target import target_f
|
32
|
+
from .target import target_f_parameters
|
33
|
+
from .target import target_f_validate
|
12
34
|
from .target import target_mixture_f
|
35
|
+
from .target import target_mixture_f_parameters
|
36
|
+
from .target import target_mixture_f_validate
|
13
37
|
from .target import target_swin_f
|
38
|
+
from .target import target_swin_f_parameters
|
39
|
+
from .target import target_swin_f_validate
|
@@ -1,25 +1,26 @@
|
|
1
1
|
from sonusai.mixture.datatypes import Truth
|
2
|
-
from sonusai.mixture.truth_functions.
|
2
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
3
4
|
|
4
5
|
|
5
|
-
def _core(data:
|
6
|
+
def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) -> Truth:
|
6
7
|
import numpy as np
|
7
8
|
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
noise_f =
|
9
|
+
if config.target_fft.bins != config.noise_fft.bins:
|
10
|
+
raise ValueError("Transform size mismatch for crm truth")
|
11
|
+
|
12
|
+
frames = len(data.target_audio) // config.frame_size
|
13
|
+
truth = np.empty((frames, config.target_fft.bins * 2), dtype=np.float32)
|
14
|
+
for frame in range(frames):
|
15
|
+
offset = frame * config.frame_size
|
16
|
+
target_f = config.target_fft.execute(data.target_audio[offset : offset + config.frame_size]).astype(
|
17
|
+
np.complex64
|
18
|
+
)
|
19
|
+
noise_f = config.noise_fft.execute(data.noise_audio[offset : offset + config.frame_size]).astype(np.complex64)
|
19
20
|
mixture_f = target_f + noise_f
|
20
21
|
|
21
22
|
crm_data = np.empty(target_f.shape, dtype=np.complex64)
|
22
|
-
with np.nditer(target_f, flags=[
|
23
|
+
with np.nditer(target_f, flags=["multi_index"], op_flags=[["readwrite"]]) as it:
|
23
24
|
for _ in it:
|
24
25
|
num = target_f[it.multi_index]
|
25
26
|
den = mixture_f[it.multi_index]
|
@@ -30,44 +31,61 @@ def _core(data: Data, polar: bool) -> Truth:
|
|
30
31
|
else:
|
31
32
|
crm_data[it.multi_index] = num / den
|
32
33
|
|
33
|
-
|
34
|
+
truth[frame, : config.target_fft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
|
35
|
+
truth[frame, config.target_fft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
|
36
|
+
|
37
|
+
return truth
|
34
38
|
|
35
|
-
def c1(c_data: np.ndarray, is_polar: bool) -> np.ndarray:
|
36
|
-
if is_polar:
|
37
|
-
return np.absolute(c_data)
|
38
|
-
return np.real(c_data)
|
39
39
|
|
40
|
-
|
41
|
-
|
42
|
-
return np.angle(c_data)
|
43
|
-
return np.imag(c_data)
|
40
|
+
def crm_validate(_config: dict) -> None:
|
41
|
+
pass
|
44
42
|
|
45
|
-
for index in data.zero_based_indices:
|
46
|
-
data.truth[indices, index:index + data.target_fft.bins] = c1(crm_data, polar)
|
47
|
-
data.truth[indices, (index + data.target_fft.bins):(index + 2 * data.target_fft.bins)] = c2(crm_data, polar)
|
48
43
|
|
49
|
-
|
44
|
+
def crm_parameters(config: TruthFunctionConfig) -> int:
|
45
|
+
return config.target_fft.bins * 2
|
50
46
|
|
51
47
|
|
52
|
-
def crm(data:
|
48
|
+
def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
53
49
|
"""Complex ratio mask truth generation function
|
54
50
|
|
55
|
-
Calculates the true complex ratio mask (CRM) truth which is a complex number
|
56
|
-
per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as
|
51
|
+
Calculates the true complex ratio mask (CRM) truth which is a complex number
|
52
|
+
per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as
|
57
53
|
|
58
|
-
(Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)
|
54
|
+
(Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)
|
59
55
|
|
60
|
-
Output shape: [:, bins]
|
56
|
+
Output shape: [:, 2 * bins]
|
61
57
|
"""
|
62
|
-
|
58
|
+
import numpy as np
|
59
|
+
|
60
|
+
frames = config.target_fft.frames(data.target_audio)
|
61
|
+
parameters = crm_parameters(config)
|
62
|
+
if config.target_gain == 0:
|
63
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
64
|
+
|
65
|
+
return _core(data=data, config=config, polar=False)
|
63
66
|
|
64
67
|
|
65
|
-
def
|
68
|
+
def crmp_validate(_config: dict) -> None:
|
69
|
+
pass
|
70
|
+
|
71
|
+
|
72
|
+
def crmp_parameters(config: TruthFunctionConfig) -> int:
|
73
|
+
return config.target_fft.bins * 2
|
74
|
+
|
75
|
+
|
76
|
+
def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
66
77
|
"""Complex ratio mask polar truth generation function
|
67
78
|
|
68
|
-
Same as the crm function except the results are magnitude and phase
|
69
|
-
instead of real and imaginary.
|
79
|
+
Same as the crm function except the results are magnitude and phase
|
80
|
+
instead of real and imaginary.
|
70
81
|
|
71
|
-
Output shape: [:, bins]
|
82
|
+
Output shape: [:, bins]
|
72
83
|
"""
|
73
|
-
|
84
|
+
import numpy as np
|
85
|
+
|
86
|
+
frames = config.target_fft.frames(data.target_audio)
|
87
|
+
parameters = crmp_parameters(config)
|
88
|
+
if config.target_gain == 0:
|
89
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
90
|
+
|
91
|
+
return _core(data=data, config=config, polar=True)
|
@@ -0,0 +1,37 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
from sonusai.mixture.datatypes import AudioT
|
4
|
+
|
5
|
+
|
6
|
+
class TruthFunctionConfig:
|
7
|
+
def __init__(self, feature: str, num_classes: int, class_indices: list[int], target_gain: float, config: dict):
|
8
|
+
from pyaaware import ForwardTransform
|
9
|
+
from pyaaware import InverseTransform
|
10
|
+
from pyaaware import feature_forward_transform_config
|
11
|
+
from pyaaware import feature_inverse_transform_config
|
12
|
+
from pyaaware import feature_parameters
|
13
|
+
|
14
|
+
self.feature = feature
|
15
|
+
self.num_classes = num_classes
|
16
|
+
self.class_indices = class_indices
|
17
|
+
self.target_gain = target_gain
|
18
|
+
self.config = config
|
19
|
+
|
20
|
+
self.feature_parameters = feature_parameters(feature)
|
21
|
+
ft_config = feature_forward_transform_config(feature)
|
22
|
+
it_config = feature_inverse_transform_config(feature)
|
23
|
+
|
24
|
+
self.ttype = it_config["ttype"]
|
25
|
+
self.frame_size = it_config["overlap"]
|
26
|
+
|
27
|
+
self.target_fft = ForwardTransform(**ft_config)
|
28
|
+
self.noise_fft = ForwardTransform(**ft_config)
|
29
|
+
self.mixture_fft = ForwardTransform(**ft_config)
|
30
|
+
self.swin = InverseTransform(**it_config).window
|
31
|
+
|
32
|
+
|
33
|
+
@dataclass
|
34
|
+
class TruthFunctionData:
|
35
|
+
target_audio: AudioT
|
36
|
+
noise_audio: AudioT
|
37
|
+
mixture_audio: AudioT
|
@@ -1,69 +1,43 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
3
|
from sonusai.mixture.datatypes import Truth
|
4
|
-
from sonusai.mixture.truth_functions.
|
4
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
5
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
5
6
|
|
6
7
|
|
7
|
-
def _core(data:
|
8
|
-
from sonusai import SonusAIError
|
8
|
+
def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, snr: bool) -> Truth:
|
9
9
|
from sonusai.utils import compute_energy_f
|
10
10
|
|
11
|
-
|
12
|
-
snr_db_std = None
|
13
|
-
if mapped:
|
14
|
-
if data.config.config is None:
|
15
|
-
raise SonusAIError('Truth function mapped SNR missing config')
|
16
|
-
|
17
|
-
parameters = ['snr_db_mean', 'snr_db_std']
|
18
|
-
for parameter in parameters:
|
19
|
-
if parameter not in data.config.config:
|
20
|
-
raise SonusAIError(f'Truth function mapped_snr_f config missing required parameter: {parameter}')
|
21
|
-
|
22
|
-
snr_db_mean = data.config.config['snr_db_mean']
|
23
|
-
if len(snr_db_mean) != data.target_fft.bins:
|
24
|
-
raise SonusAIError(f'Truth function mapped_snr_f snr_db_mean does not have {data.target_fft.bins} elements')
|
25
|
-
|
26
|
-
snr_db_std = data.config.config['snr_db_std']
|
27
|
-
if len(snr_db_std) != data.target_fft.bins:
|
28
|
-
raise SonusAIError(f'Truth function mapped_snr_f snr_db_std does not have {data.target_fft.bins} elements')
|
29
|
-
|
30
|
-
for index in data.zero_based_indices:
|
31
|
-
if index + data.target_fft.bins > data.config.num_classes:
|
32
|
-
raise SonusAIError('Truth index exceeds the number of classes')
|
33
|
-
|
34
|
-
target_energy = compute_energy_f(time_domain=data.target_audio, transform=data.target_fft)
|
11
|
+
target_energy = compute_energy_f(time_domain=data.target_audio, transform=config.target_fft)
|
35
12
|
noise_energy = None
|
36
13
|
if snr:
|
37
|
-
noise_energy = compute_energy_f(time_domain=data.noise_audio, transform=
|
38
|
-
|
39
|
-
if len(target_energy) != len(data.offsets):
|
40
|
-
raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
|
41
|
-
f' is not number of frames in truth, {len(data.offsets)}')
|
14
|
+
noise_energy = compute_energy_f(time_domain=data.noise_audio, transform=config.noise_fft)
|
42
15
|
|
43
|
-
|
44
|
-
|
16
|
+
frames = len(target_energy)
|
17
|
+
truth = np.empty((frames, config.target_fft.bins), dtype=np.float32)
|
18
|
+
for frame in range(frames):
|
19
|
+
tmp = target_energy[frame]
|
45
20
|
|
46
|
-
if
|
47
|
-
old_err = np.seterr(divide=
|
48
|
-
tmp /= noise_energy[
|
21
|
+
if noise_energy is not None:
|
22
|
+
old_err = np.seterr(divide="ignore", invalid="ignore")
|
23
|
+
tmp /= noise_energy[frame]
|
49
24
|
np.seterr(**old_err)
|
50
25
|
|
51
26
|
tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
52
27
|
|
53
28
|
if mapped:
|
54
|
-
tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
|
29
|
+
tmp = _calculate_mapped_snr_f(tmp, config.config["snr_db_mean"], config.config["snr_db_std"])
|
55
30
|
|
56
|
-
|
57
|
-
data.truth[offset:offset + data.frame_size, index:index + data.target_fft.bins] = tmp
|
31
|
+
truth[frame] = tmp
|
58
32
|
|
59
|
-
return
|
33
|
+
return truth
|
60
34
|
|
61
35
|
|
62
36
|
def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray:
|
63
37
|
"""Calculate mapped SNR from standard SNR energy per bin/class."""
|
64
38
|
import scipy.special as sc
|
65
39
|
|
66
|
-
old_err = np.seterr(divide=
|
40
|
+
old_err = np.seterr(divide="ignore", invalid="ignore")
|
67
41
|
num = 10 * np.log10(np.double(truth_f)) - np.double(snr_db_mean)
|
68
42
|
den = np.double(snr_db_std) * np.sqrt(2)
|
69
43
|
q = num / den
|
@@ -74,7 +48,15 @@ def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db
|
|
74
48
|
return result.astype(np.float32)
|
75
49
|
|
76
50
|
|
77
|
-
def
|
51
|
+
def energy_f_validate(_config: dict) -> None:
|
52
|
+
pass
|
53
|
+
|
54
|
+
|
55
|
+
def energy_f_parameters(config: TruthFunctionConfig) -> int:
|
56
|
+
return config.target_fft.bins
|
57
|
+
|
58
|
+
|
59
|
+
def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
78
60
|
"""Frequency domain energy truth generation function
|
79
61
|
|
80
62
|
Calculates the true energy per bin:
|
@@ -85,10 +67,23 @@ def energy_f(data: Data) -> Truth:
|
|
85
67
|
|
86
68
|
Output shape: [:, bins]
|
87
69
|
"""
|
88
|
-
|
70
|
+
frames = config.target_fft.frames(data.target_audio)
|
71
|
+
parameters = energy_f_parameters(config)
|
72
|
+
if config.target_gain == 0:
|
73
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
74
|
+
|
75
|
+
return _core(data=data, config=config, mapped=False, snr=False)
|
76
|
+
|
77
|
+
|
78
|
+
def snr_f_validate(_config: dict) -> None:
|
79
|
+
pass
|
80
|
+
|
81
|
+
|
82
|
+
def snr_f_parameters(config: TruthFunctionConfig) -> int:
|
83
|
+
return config.target_fft.bins
|
89
84
|
|
90
85
|
|
91
|
-
def snr_f(data:
|
86
|
+
def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
92
87
|
"""Frequency domain SNR truth function documentation
|
93
88
|
|
94
89
|
Calculates the true SNR per bin:
|
@@ -99,18 +94,54 @@ def snr_f(data: Data) -> Truth:
|
|
99
94
|
|
100
95
|
Output shape: [:, bins]
|
101
96
|
"""
|
102
|
-
|
97
|
+
frames = config.target_fft.frames(data.target_audio)
|
98
|
+
parameters = snr_f_parameters(config)
|
99
|
+
if config.target_gain == 0:
|
100
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
103
101
|
|
102
|
+
return _core(data=data, config=config, mapped=False, snr=True)
|
104
103
|
|
105
|
-
|
104
|
+
|
105
|
+
def mapped_snr_f_validate(config: TruthFunctionConfig) -> None:
|
106
|
+
if len(config.config) == 0:
|
107
|
+
raise AttributeError("mapped_snr_f truth function is missing config")
|
108
|
+
|
109
|
+
for parameter in ("snr_db_mean", "snr_db_std"):
|
110
|
+
if parameter not in config.config:
|
111
|
+
raise AttributeError(f"mapped_snr_f truth function is missing required '{parameter}'")
|
112
|
+
|
113
|
+
if len(config.config[parameter]) != config.target_fft.bins:
|
114
|
+
raise ValueError(
|
115
|
+
f"mapped_snr_f truth function '{parameter}' does not have {config.target_fft.bins} elements"
|
116
|
+
)
|
117
|
+
|
118
|
+
|
119
|
+
def mapped_snr_f_parameters(config: TruthFunctionConfig) -> int:
|
120
|
+
return config.target_fft.bins
|
121
|
+
|
122
|
+
|
123
|
+
def mapped_snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
106
124
|
"""Frequency domain mapped SNR truth function documentation
|
107
125
|
|
108
126
|
Output shape: [:, bins]
|
109
127
|
"""
|
110
|
-
|
128
|
+
frames = config.target_fft.frames(data.target_audio)
|
129
|
+
parameters = mapped_snr_f_parameters(config)
|
130
|
+
if config.target_gain == 0:
|
131
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
111
132
|
|
133
|
+
return _core(data=data, config=config, mapped=True, snr=True)
|
112
134
|
|
113
|
-
|
135
|
+
|
136
|
+
def energy_t_validate(_config: dict) -> None:
|
137
|
+
pass
|
138
|
+
|
139
|
+
|
140
|
+
def energy_t_parameters(_config: TruthFunctionConfig) -> int:
|
141
|
+
return 1
|
142
|
+
|
143
|
+
|
144
|
+
def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
114
145
|
"""Time domain energy truth function documentation
|
115
146
|
|
116
147
|
Calculates the true time domain energy of each frame:
|
@@ -134,14 +165,9 @@ def energy_t(data: Data) -> Truth:
|
|
134
165
|
"""
|
135
166
|
import torch
|
136
167
|
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
|
142
|
-
f' is not number of frames in truth, {len(data.offsets)}')
|
143
|
-
|
144
|
-
for offset in data.offsets:
|
145
|
-
data.truth[offset:offset + data.frame_size, data.zero_based_indices] = np.float32(target_energy)
|
168
|
+
frames = config.target_fft.frames(data.target_audio)
|
169
|
+
parameters = energy_t_parameters(config)
|
170
|
+
if config.target_gain == 0:
|
171
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
146
172
|
|
147
|
-
return data.
|
173
|
+
return config.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
|
@@ -1,44 +1,44 @@
|
|
1
1
|
from sonusai.mixture.datatypes import Truth
|
2
|
-
from sonusai.mixture.truth_functions.
|
2
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
3
4
|
|
4
5
|
|
5
|
-
def
|
6
|
-
"""file truth function documentation
|
7
|
-
"""
|
6
|
+
def file_validate(config: dict) -> None:
|
8
7
|
import h5py
|
9
|
-
import numpy as np
|
10
8
|
|
11
|
-
|
9
|
+
if len(config) == 0:
|
10
|
+
raise AttributeError("file truth function is missing config")
|
11
|
+
|
12
|
+
if "file" not in config:
|
13
|
+
raise AttributeError("file truth function is missing required 'file'")
|
14
|
+
|
15
|
+
with h5py.File(config["file"], "r") as f:
|
16
|
+
if "truth_f" not in f:
|
17
|
+
raise ValueError("Truth file does not contain truth_f dataset")
|
18
|
+
|
12
19
|
|
13
|
-
|
14
|
-
|
20
|
+
def file_parameters(config: TruthFunctionConfig) -> int:
|
21
|
+
import h5py
|
22
|
+
import numpy as np
|
15
23
|
|
16
|
-
|
17
|
-
|
18
|
-
if 'file' not in data.config.config:
|
19
|
-
raise SonusAIError(f'Truth function file config missing required parameter: {parameter}')
|
24
|
+
with h5py.File(config.config["file"], "r") as f:
|
25
|
+
truth = np.array(f["truth_f"])
|
20
26
|
|
21
|
-
|
22
|
-
if 'truth_t' not in f:
|
23
|
-
raise SonusAIError('Truth file does not contain truth_t dataset')
|
24
|
-
truth_in = np.array(f['truth_t'])
|
27
|
+
return truth.shape[-1]
|
25
28
|
|
26
|
-
if truth_in.ndim != 2:
|
27
|
-
raise SonusAIError('Truth file data is not 2 dimensions')
|
28
29
|
|
29
|
-
|
30
|
-
|
30
|
+
def file(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
31
|
+
"""file truth function documentation"""
|
32
|
+
import h5py
|
33
|
+
import numpy as np
|
31
34
|
|
32
|
-
|
33
|
-
|
34
|
-
raise SonusAIError('Truth file does not contain the right amount of classes')
|
35
|
+
with h5py.File(config.config["file"], "r") as f:
|
36
|
+
truth = np.array(f["truth_f"])
|
35
37
|
|
36
|
-
|
37
|
-
|
38
|
-
index = data.zero_based_indices[0]
|
39
|
-
if index + truth_in.shape[1] > data.config.num_classes:
|
40
|
-
raise SonusAIError('Truth file contains too many classes')
|
38
|
+
if truth.ndim != 2:
|
39
|
+
raise ValueError("Truth file data is not 2 dimensions")
|
41
40
|
|
42
|
-
|
41
|
+
if truth.shape[0] != len(data.target_audio) // config.frame_size:
|
42
|
+
raise ValueError("Truth file does not contain the right amount of frames")
|
43
43
|
|
44
|
-
return
|
44
|
+
return truth
|
@@ -1,12 +1,19 @@
|
|
1
1
|
from sonusai.mixture.datatypes import Truth
|
2
|
-
from sonusai.mixture.truth_functions.
|
2
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
3
4
|
|
4
5
|
|
5
|
-
def
|
6
|
+
def phoneme_validate(_config: dict) -> None:
|
7
|
+
raise NotImplementedError("Truth function phoneme is not supported yet")
|
8
|
+
|
9
|
+
|
10
|
+
def phoneme_parameters(_config: TruthFunctionConfig) -> int:
|
11
|
+
raise NotImplementedError("Truth function phoneme is not supported yet")
|
12
|
+
|
13
|
+
|
14
|
+
def phoneme(_data: TruthFunctionData, _config: TruthFunctionConfig) -> Truth:
|
6
15
|
"""Read in .txt transcript and run a Python function to generate text grid data
|
7
|
-
(indicating which phonemes are active). Then generate truth based on this data and put
|
8
|
-
in the correct classes based on the index in the config.
|
16
|
+
(indicating which phonemes are active). Then generate truth based on this data and put
|
17
|
+
in the correct classes based on the index in the config.
|
9
18
|
"""
|
10
|
-
|
11
|
-
|
12
|
-
raise SonusAIError('Truth function phoneme is not supported yet')
|
19
|
+
raise NotImplementedError("Truth function phoneme is not supported yet")
|