sonusai 1.0.16__cp311-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,51 @@
|
|
1
|
+
from ..datatypes import AudioF
|
2
|
+
from ..datatypes import SpectralMask
|
3
|
+
|
4
|
+
|
5
|
+
def apply_spectral_mask(audio_f: AudioF, spectral_mask: SpectralMask, seed: int | None = None) -> AudioF:
    """Apply frequency and time masking

    Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

    Ref: https://arxiv.org/pdf/1904.08779.pdf

    f_width consecutive bins [f_start, f_start + f_width) are masked, where f_width is chosen from a uniform
    distribution from 0 to the f_max_width, and f_start is chosen from [0, bins - f_width] (inclusive).
    An f_max_width outside of 0..bins is replaced with bins.

    t_width consecutive frames [t_start, t_start + t_width) are masked, where t_width is chosen from a uniform
    distribution from 0 to the t_max_width, and t_start is chosen from [0, frames - t_width] (inclusive).

    A time mask cannot be wider than t_max_percent percent of the number of frames.

    Note: the masks are written into audio_f in place and that same array is returned.

    :param audio_f: Numpy array of transform audio data [frames, bins]
    :param spectral_mask: Spectral mask parameters
    :param seed: Random number seed
    :return: Augmented feature
    :raises ValueError: If audio_f does not have exactly two dimensions
    """
    import numpy as np

    if audio_f.ndim != 2:
        raise ValueError("feature input must have two dimensions [frames, bins]")

    frames, bins = audio_f.shape

    f_max_width = spectral_mask.f_max_width
    if f_max_width not in range(0, bins + 1):
        f_max_width = bins

    rng = np.random.default_rng(seed)

    # apply f_num frequency masks to the feature
    for _ in range(spectral_mask.f_num):
        f_width = int(rng.uniform(0, f_max_width))
        f_start = rng.integers(0, bins - f_width, endpoint=True)
        audio_f[:, f_start : f_start + f_width] = 0

    # apply t_num time masks to the feature
    t_upper_bound = int(spectral_mask.t_max_percent / 100 * frames)
    for _ in range(spectral_mask.t_num):
        t_width = min(int(rng.uniform(0, spectral_mask.t_max_width)), t_upper_bound)
        t_start = rng.integers(0, frames - t_width, endpoint=True)
        audio_f[t_start : t_start + t_width, :] = 0

    return audio_f
|
sonusai/mixture/truth.py
ADDED
@@ -0,0 +1,61 @@
|
|
1
|
+
from ..datatypes import Truth
|
2
|
+
from ..datatypes import TruthsDict
|
3
|
+
from .mixdb import MixtureDatabase
|
4
|
+
|
5
|
+
|
6
|
+
def truth_function(mixdb: MixtureDatabase, m_id: int) -> TruthsDict:
    """Generate all configured truths for every source of a mixture.

    For each source category of the mixture, every truth config of the source file is
    dispatched by name to the matching function in the truth_functions package.
    Categories that produce no truth are left out of the result.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :return: Truth data keyed by category, then by truth name
    :raises AttributeError: If a configured truth function cannot be resolved
    :raises RuntimeError: If a truth function fails for any other reason
    """
    from ..datatypes import TruthDict
    from . import truth_functions

    truths: TruthsDict = {}
    for category, source in mixdb.mixture(m_id).all_sources.items():
        truth_configs = mixdb.source_file(source.file_id).truth_configs

        category_truth: TruthDict = {}
        for name, truth_config in truth_configs.items():
            try:
                generator = getattr(truth_functions, truth_config.function)
                category_truth[name] = generator(mixdb, m_id, category, truth_config.config)
            except AttributeError as e:
                raise AttributeError(f"Unsupported truth function: {truth_config.function}") from e
            except Exception as e:
                raise RuntimeError(f"Error in truth function '{truth_config.function}': {e}") from e

        if category_truth:
            truths[category] = category_truth

    return truths
|
26
|
+
|
27
|
+
|
28
|
+
def get_class_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
    """Get a list of class indices for a given mixid.

    Collects the class indices of every source file used by the mixture and returns
    them sorted with duplicates removed.

    :param mixdb: Mixture database
    :param mixid: Mixture ID
    :return: Sorted list of unique class indices
    """
    indices: set[int] = set()
    for source in mixdb.mixture(mixid).all_sources.values():
        # A source file may carry zero, one or several class indices, so merge the
        # whole collection (append(*...) only works for exactly one index).
        indices.update(mixdb.source_file(source.file_id).class_indices)

    return sorted(indices)
|
35
|
+
|
36
|
+
|
37
|
+
def truth_stride_reduction(truth: Truth, function: str) -> Truth:
    """Collapse the stride axis of truth data.

    Supported functions are "none" (return the input unchanged), "max" and "mean"
    (reduce over the stride axis, keeping it as length 1) and "first" (keep only the
    first stride entry).

    :param truth: Truth data [frames, stride, truth_parameters]
    :param function: Truth stride reduction function name
    :return: Stride reduced truth data [frames, stride or 1, truth_parameters]
    :raises ValueError: If truth is not three-dimensional or function is not supported
    """
    import numpy as np

    if truth.ndim != 3:
        raise ValueError("Invalid truth shape")

    if function == "none":
        reduced = truth
    elif function == "max":
        reduced = np.max(truth, axis=1, keepdims=True)
    elif function == "mean":
        reduced = np.mean(truth, axis=1, keepdims=True)
    elif function == "first":
        first_stride = truth[:, 0, :]
        reduced = first_stride.reshape((truth.shape[0], 1, truth.shape[2]))
    else:
        raise ValueError(f"Invalid truth stride reduction function: {function}")

    return reduced
|
@@ -0,0 +1,45 @@
|
|
1
|
+
# SonusAI truth functions
|
2
|
+
# ruff: noqa: F401
|
3
|
+
|
4
|
+
from .crm import crm
|
5
|
+
from .crm import crm_parameters
|
6
|
+
from .crm import crm_validate
|
7
|
+
from .crm import crmp
|
8
|
+
from .crm import crmp_parameters
|
9
|
+
from .crm import crmp_validate
|
10
|
+
from .energy import energy_f
|
11
|
+
from .energy import energy_f_parameters
|
12
|
+
from .energy import energy_f_validate
|
13
|
+
from .energy import energy_t
|
14
|
+
from .energy import energy_t_parameters
|
15
|
+
from .energy import energy_t_validate
|
16
|
+
from .energy import mapped_snr_f
|
17
|
+
from .energy import mapped_snr_f_parameters
|
18
|
+
from .energy import mapped_snr_f_validate
|
19
|
+
from .energy import snr_f
|
20
|
+
from .energy import snr_f_parameters
|
21
|
+
from .energy import snr_f_validate
|
22
|
+
from .file import file
|
23
|
+
from .file import file_parameters
|
24
|
+
from .file import file_validate
|
25
|
+
from .metadata import metadata
|
26
|
+
from .metadata import metadata_parameters
|
27
|
+
from .metadata import metadata_validate
|
28
|
+
from .metrics import metrics
|
29
|
+
from .metrics import metrics_parameters
|
30
|
+
from .metrics import metrics_validate
|
31
|
+
from .phoneme import phoneme
|
32
|
+
from .phoneme import phoneme_parameters
|
33
|
+
from .phoneme import phoneme_validate
|
34
|
+
from .sed import sed
|
35
|
+
from .sed import sed_parameters
|
36
|
+
from .sed import sed_validate
|
37
|
+
from .target import target_f
|
38
|
+
from .target import target_f_parameters
|
39
|
+
from .target import target_f_validate
|
40
|
+
from .target import target_mixture_f
|
41
|
+
from .target import target_mixture_f_parameters
|
42
|
+
from .target import target_mixture_f_validate
|
43
|
+
from .target import target_swin_f
|
44
|
+
from .target import target_swin_f_parameters
|
45
|
+
from .target import target_swin_f_validate
|
@@ -0,0 +1,105 @@
|
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
|
+
|
4
|
+
|
5
|
+
def _core(mixdb: MixtureDatabase, m_id: int, category: str, parameters: int, polar: bool) -> Truth:
    """Compute the per-bin complex ratio of source to mixture (source + noise).

    Each frame of source and noise audio is run through its own forward transform and
    the ratio target / (target + noise) is formed bin by bin. A zero numerator gives 0;
    a zero denominator with a non-zero numerator gives inf + j*inf.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :param category: Source category to generate truth for
    :param parameters: Number of truth parameters; only used to size the all-zero result
        returned when the source's snr_gain is 0
    :param polar: If True store magnitude and angle, otherwise real and imaginary parts
    :return: Truth data [frames, 2 * bins]; first half/second half of the columns hold
        real/imaginary (or magnitude/angle) values
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config
    from pyaaware import feature_inverse_transform_config

    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
    # Source and noise each get their own transform instance
    t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
    n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = t_ft.frames(source_audio)
    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))

    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]

    frames = len(source_audio) // frame_size
    truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
    for frame in range(frames):
        start = frame * frame_size
        stop = start + frame_size
        target_f = t_ft.execute(source_audio[start:stop])[0].numpy().astype(np.complex64)
        noise_f = n_ft.execute(noise_audio[start:stop])[0].numpy().astype(np.complex64)
        mixture_f = target_f + noise_f

        crm_data = np.empty(target_f.shape, dtype=np.complex64)
        with np.nditer(target_f, flags=["multi_index"], op_flags=[["readwrite"]]) as it:
            for _ in it:
                index = it.multi_index
                num = target_f[index]
                den = mixture_f[index]
                if num == 0:
                    crm_data[index] = 0
                elif den == 0:
                    crm_data[index] = complex(np.inf, np.inf)
                else:
                    crm_data[index] = num / den

        if polar:
            first_half = np.absolute(crm_data)
            second_half = np.angle(crm_data)
        else:
            first_half = np.real(crm_data)
            second_half = np.imag(crm_data)

        truth[frame, : t_ft.bins] = first_half
        truth[frame, t_ft.bins :] = second_half

    return truth
|
48
|
+
|
49
|
+
|
50
|
+
def crm_validate(_config: dict) -> None:
    """Validate the crm truth config; nothing is checked, so any config is accepted."""
    return None
|
52
|
+
|
53
|
+
|
54
|
+
def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of crm truth parameters: twice the bin count of the feature's forward transform.

    :param feature: Feature name used to look up the forward transform config
    :param _num_classes: Unused
    :param _config: Unused
    :return: 2 * bins
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    return transform.bins * 2
|
59
|
+
|
60
|
+
|
61
|
+
def crm(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Complex ratio mask truth generation function

    Calculates the true complex ratio mask (CRM) truth which is a complex number
    per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as

    (Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)

    The real parts are stored in the first `bins` columns and the imaginary parts
    in the remaining `bins` columns.

    Output shape: [:, 2 * bins]
    """
    parameters = crm_parameters(mixdb.feature, mixdb.num_classes, _config)
    return _core(mixdb=mixdb, m_id=m_id, category=category, parameters=parameters, polar=False)
|
78
|
+
|
79
|
+
|
80
|
+
def crmp_validate(_config: dict) -> None:
    """Validate the crmp truth config; nothing is checked, so any config is accepted."""
    return None
|
82
|
+
|
83
|
+
|
84
|
+
def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of crmp truth parameters: twice the bin count of the feature's forward transform.

    :param feature: Feature name used to look up the forward transform config
    :param _num_classes: Unused
    :param _config: Unused
    :return: 2 * bins
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    return transform.bins * 2
|
89
|
+
|
90
|
+
|
91
|
+
def crmp(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Complex ratio mask polar truth generation function

    Same as the crm function except the results are magnitude and phase
    instead of real and imaginary. The magnitudes are stored in the first
    `bins` columns and the phase angles in the remaining `bins` columns.

    Output shape: [:, 2 * bins]
    """
    return _core(
        mixdb=mixdb,
        m_id=m_id,
        category=category,
        parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
        polar=True,
    )
|
@@ -0,0 +1,222 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
from ...datatypes import Truth
|
4
|
+
from ...utils.load_object import load_object
|
5
|
+
from ..mixdb import MixtureDatabase
|
6
|
+
|
7
|
+
|
8
|
+
def _core(
    mixdb: MixtureDatabase,
    m_id: int,
    category: str,
    config: dict,
    parameters: int,
    mapped: bool,
    snr: bool,
    use_cache: bool = True,
) -> Truth:
    """Shared implementation of the frequency domain energy/SNR truth functions.

    Computes the per-bin energy of the source audio for each frame. When snr is set
    the energy is divided by the per-bin noise energy; when mapped is set the result
    is additionally passed through _calculate_mapped_snr_f.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :param category: Source category to generate truth for
    :param config: Truth config; 'snr_db_mean' and 'snr_db_std' are read when mapped is set
    :param parameters: Number of truth parameters; only used to size the all-zero result
        returned when the source's snr_gain is 0
    :param mapped: Apply the mapped SNR transformation
    :param snr: Divide source energy by noise energy
    :param use_cache: Passed through to load_object
    :return: Truth data [frames, bins]
    """
    from os.path import join

    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    from ...utils.energy_f import compute_energy_f

    source_audio = mixdb.mixture_sources(m_id)[category]
    ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = ft.frames(torch.from_numpy(source_audio))

    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    noise_audio = mixdb.mixture_noise(m_id)

    source_energy = compute_energy_f(time_domain=source_audio, transform=ft)
    noise_energy = None
    if snr:
        noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)

    frames = len(source_energy)

    # The mapping statistics do not depend on the frame, so load them once up front.
    # They are only loaded when at least one frame will actually use them.
    snr_db_mean = None
    snr_db_std = None
    if mapped and frames > 0:
        snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]), use_cache)
        snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]), use_cache)

    truth = np.empty((frames, ft.bins), dtype=np.float32)
    for frame in range(frames):
        tmp = source_energy[frame]

        if noise_energy is not None:
            # Division by zero is expected here; errstate restores the previous
            # floating-point error handling even if the division raises.
            # NOTE: this divides in place, i.e. it writes into source_energy[frame].
            with np.errstate(divide="ignore", invalid="ignore"):
                tmp /= noise_energy[frame]

        # NaN (e.g. 0/0) is reported as -inf; infinities are kept as they are
        tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)

        if mapped:
            tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)

        truth[frame] = tmp

    return truth
|
61
|
+
|
62
|
+
|
63
|
+
def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray:
|
64
|
+
"""Calculate mapped SNR from standard SNR energy per bin/class."""
|
65
|
+
import scipy.special as sc
|
66
|
+
|
67
|
+
old_err = np.seterr(divide="ignore", invalid="ignore")
|
68
|
+
num = 10 * np.log10(np.double(truth_f)) - np.double(snr_db_mean)
|
69
|
+
den = np.double(snr_db_std) * np.sqrt(2)
|
70
|
+
q = num / den
|
71
|
+
q = np.nan_to_num(q, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
72
|
+
result = 0.5 * (1 + sc.erf(q))
|
73
|
+
np.seterr(**old_err)
|
74
|
+
|
75
|
+
return result.astype(np.float32)
|
76
|
+
|
77
|
+
|
78
|
+
def energy_f_validate(_config: dict) -> None:
    """Validate the energy_f truth config; nothing is checked, so any config is accepted."""
    return None
|
80
|
+
|
81
|
+
|
82
|
+
def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of energy_f truth parameters: the bin count of the feature's forward transform.

    :param feature: Feature name used to look up the forward transform config
    :param _num_classes: Unused
    :param _config: Unused
    :return: Number of bins
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    return transform.bins
|
87
|
+
|
88
|
+
|
89
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain energy truth generation function

    Calculates the true energy per bin:

    Ti^2 + Tr^2

    where T is the target STFT bin values.

    Output shape: [:, bins]
    """
    parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb,
        m_id,
        category,
        config,
        parameters,
        mapped=False,
        snr=False,
        use_cache=use_cache,
    )
|
110
|
+
|
111
|
+
|
112
|
+
def snr_f_validate(_config: dict) -> None:
    """Validate the snr_f truth config; nothing is checked, so any config is accepted."""
    return None
|
114
|
+
|
115
|
+
|
116
|
+
def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of snr_f truth parameters: the bin count of the feature's forward transform.

    :param feature: Feature name used to look up the forward transform config
    :param _num_classes: Unused
    :param _config: Unused
    :return: Number of bins
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    return transform.bins
|
121
|
+
|
122
|
+
|
123
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain SNR truth generation function

    Calculates the true SNR per bin:

    (Ti^2 + Tr^2) / (Ni^2 + Nr^2)

    where T is the target and N is the noise STFT bin values.

    Output shape: [:, bins]
    """
    parameters = snr_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb,
        m_id,
        category,
        config,
        parameters,
        mapped=False,
        snr=True,
        use_cache=use_cache,
    )
|
144
|
+
|
145
|
+
|
146
|
+
def mapped_snr_f_validate(config: dict) -> None:
    """Check that the mapped_snr_f truth config provides 'snr_db_mean' and 'snr_db_std'.

    :param config: Truth config
    :raises AttributeError: If the config is empty or a required entry is missing
    """
    if len(config) == 0:
        raise AttributeError("mapped_snr_f truth function is missing config")

    required = ("snr_db_mean", "snr_db_std")
    # Report the first missing entry, in declaration order
    missing = next((name for name in required if name not in config), None)
    if missing is not None:
        raise AttributeError(f"mapped_snr_f truth function is missing required '{missing}'")
|
153
|
+
|
154
|
+
|
155
|
+
def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of mapped_snr_f truth parameters: the bin count of the feature's forward transform.

    :param feature: Feature name used to look up the forward transform config
    :param _num_classes: Unused
    :param _config: Unused
    :return: Number of bins
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    return transform.bins
|
160
|
+
|
161
|
+
|
162
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain mapped SNR truth generation function

    Calculates the SNR per bin and maps it using the objects named by the
    'snr_db_mean' and 'snr_db_std' config entries, which are loaded relative to
    the mixture database location.

    Output shape: [:, bins]
    """
    parameters = mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb,
        m_id,
        category,
        config,
        parameters,
        mapped=True,
        snr=True,
        use_cache=use_cache,
    )
|
177
|
+
|
178
|
+
|
179
|
+
def energy_t_validate(_config: dict) -> None:
    """Validate the energy_t truth config; nothing is checked, so any config is accepted."""
    return None
|
181
|
+
|
182
|
+
|
183
|
+
def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of energy_t truth parameters, which is always 1.

    All arguments are unused.
    """
    parameters = 1
    return parameters
|
185
|
+
|
186
|
+
|
187
|
+
def energy_t(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Time domain energy truth generation function

    Calculates the true time domain energy of each frame:

    For OLS:
        sum(x[0:N-1]^2) / N

    For OLA:
        sum(x[0:R-1]^2) / R

    where x is the target time domain data,
    N is the size of the transform, and
    R is the number of new samples in the frame.

    Output shape: [:, 1]

    Note: feature transforms can be defined to use a subset of all bins,
    i.e., subset of 0:128 for N=256 could be 0:127 or 1:128. energy_t
    will reflect the total energy over all bins regardless of the feature
    transform config.
    """
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])

    ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = ft.frames(source_audio)
    # Size the all-zero result with this function's own parameter count (1), not the
    # frequency domain bin count, so it matches the documented [:, 1] output shape.
    parameters = energy_t_parameters(mixdb.feature, mixdb.num_classes, _config)
    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    # The second element returned by execute_all is used as the time domain energy
    return ft.execute_all(source_audio)[1].numpy()
|
@@ -0,0 +1,48 @@
|
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
|
+
|
4
|
+
|
5
|
+
def file_validate(config: dict) -> None:
    """Check that the file truth config names an HDF5 file containing a 'truth_f' dataset.

    :param config: Truth config; must contain 'file'
    :raises AttributeError: If the config is empty or has no 'file' entry
    :raises ValueError: If the file has no 'truth_f' dataset
    """
    import h5py

    if len(config) == 0:
        raise AttributeError("file truth function is missing config")

    if "file" not in config:
        raise AttributeError("file truth function is missing required 'file'")

    with h5py.File(config["file"], "r") as truth_file:
        has_truth = "truth_f" in truth_file
        if not has_truth:
            raise ValueError("Truth file does not contain truth_f dataset")
|
17
|
+
|
18
|
+
|
19
|
+
def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
    """Return the number of file truth parameters: the size of the last axis of 'truth_f'.

    :param _feature: Unused
    :param _num_classes: Unused
    :param config: Truth config; 'file' names the HDF5 file to inspect
    :return: Length of the last dimension of the 'truth_f' dataset
    """
    import h5py

    with h5py.File(config["file"], "r") as f:
        # Only the shape is needed, so read the dataset's metadata instead of
        # loading the whole array into memory.
        return f["truth_f"].shape[-1]
|
27
|
+
|
28
|
+
|
29
|
+
def file(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """File truth generation function

    Reads the 'truth_f' dataset from the HDF5 file named by config['file'] and returns
    it after checking that it is two-dimensional and that its number of rows equals the
    source audio length divided (integer division) by the feature's 'overlap' value.

    :raises ValueError: If the truth data has the wrong rank or number of frames
    """
    import h5py
    import numpy as np
    from pyaaware import feature_inverse_transform_config

    source_audio = mixdb.mixture_sources(m_id)[category]
    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]

    with h5py.File(config["file"], "r") as truth_file:
        truth = np.array(truth_file["truth_f"])

    if truth.ndim != 2:
        raise ValueError("Truth file data is not 2 dimensions")

    expected_frames = len(source_audio) // frame_size
    if truth.shape[0] != expected_frames:
        raise ValueError("Truth file does not contain the right amount of frames")

    return truth
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
|
+
|
4
|
+
|
5
|
+
def metadata_validate(config: dict) -> None:
    """Check that the metadata truth config provides the required 'tier' entry.

    :param config: Truth config
    :raises AttributeError: If the config is empty or a required entry is missing
    """
    if len(config) == 0:
        raise AttributeError("metadata truth function is missing config")

    required = ("tier",)
    # Report the first missing entry, in declaration order
    missing = next((name for name in required if name not in config), None)
    if missing is not None:
        raise AttributeError(f"metadata truth function is missing required '{missing}'")
|
13
|
+
|
14
|
+
|
15
|
+
def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
|
16
|
+
return None
|
17
|
+
|
18
|
+
|
19
|
+
def metadata(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """Metadata truth generation function

    Looks up the mixture's speech metadata for the tier named by config['tier']
    and returns the entry for the given category.
    """
    tier = config["tier"]
    speech_metadata = mixdb.mixture_speech_metadata(m_id, tier)
    return speech_metadata[category]
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
|
+
|
4
|
+
|
5
|
+
def metrics_validate(config: dict) -> None:
    """Check that the metrics truth config provides the required 'metric' entry.

    :param config: Truth config
    :raises AttributeError: If the config is empty or a required entry is missing
    """
    if len(config) == 0:
        raise AttributeError("metrics truth function is missing config")

    required = ("metric",)
    # Report the first missing entry, in declaration order
    missing = next((name for name in required if name not in config), None)
    if missing is not None:
        raise AttributeError(f"metrics truth function is missing required '{missing}'")
|
13
|
+
|
14
|
+
|
15
|
+
def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
|
16
|
+
return None
|
17
|
+
|
18
|
+
|
19
|
+
def metrics(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """Metrics truth generation function

    Retrieves metrics from target.

    config['metric'] may be a single metric name or a list of names. The whole list is
    passed to mixdb.mixture_metrics, but only the result for the first metric is
    returned, indexed by category.
    """
    metric = config["metric"]
    names = metric if isinstance(metric, list) else [metric]
    return mixdb.mixture_metrics(m_id, names)[names[0]][category]
|
@@ -0,0 +1,18 @@
|
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
|
+
|
4
|
+
|
5
|
+
def phoneme_validate(_config: dict) -> None:
    """Validate the phoneme truth config; always fails because phoneme truth is not implemented.

    :raises NotImplementedError: Always
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)
|
7
|
+
|
8
|
+
|
9
|
+
def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of phoneme truth parameters; always fails because phoneme truth is not implemented.

    :raises NotImplementedError: Always
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)
|
11
|
+
|
12
|
+
|
13
|
+
def phoneme(_mixdb: MixtureDatabase, _m_id: int, _category: str, _config: dict) -> Truth:
    """Phoneme truth generation function (not implemented)

    Intended behavior: read in the .txt transcript and run a Python function to generate
    text grid data (indicating which phonemes are active), then generate truth based on
    this data and put it in the correct classes based on the index in the config.

    :raises NotImplementedError: Always
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)
|