sonusai 0.19.6__py3-none-any.whl → 0.19.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +54 -62
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +202 -235
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/RECORD +58 -56
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
@@ -32,7 +32,7 @@ def _raw_read(name: str | Path) -> tuple[AudioT, int]:
|
|
32
32
|
else:
|
33
33
|
raise OSError(f"Error reading {name}: {e}") from e
|
34
34
|
|
35
|
-
return np.squeeze(raw[:, 0]), sample_rate
|
35
|
+
return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
|
36
36
|
|
37
37
|
|
38
38
|
def get_sample_rate(name: str | Path) -> int:
|
sonusai/mixture/sox_audio.py
CHANGED
@@ -207,7 +207,7 @@ class Transformer(SoxTransformer):
|
|
207
207
|
|
208
208
|
return self
|
209
209
|
|
210
|
-
def build(
|
210
|
+
def build( # pyright: ignore [reportIncompatibleMethodOverride]
|
211
211
|
self,
|
212
212
|
input_filepath: str | Path | None = None,
|
213
213
|
output_filepath: str | Path | None = None,
|
@@ -320,11 +320,11 @@ class Transformer(SoxTransformer):
|
|
320
320
|
logger.info("Created %s with effects: %s", output_filepath, " ".join(self.effects_log))
|
321
321
|
|
322
322
|
if return_output:
|
323
|
-
return status, out, err
|
323
|
+
return status, out, err # pyright: ignore [reportReturnType]
|
324
324
|
|
325
325
|
return True, None, None
|
326
326
|
|
327
|
-
def build_array(
|
327
|
+
def build_array( # pyright: ignore [reportIncompatibleMethodOverride]
|
328
328
|
self,
|
329
329
|
input_filepath: str | Path | None = None,
|
330
330
|
input_array: np.ndarray | None = None,
|
@@ -465,7 +465,7 @@ class Transformer(SoxTransformer):
|
|
465
465
|
if status != 0:
|
466
466
|
raise SoxError(f"Stdout: {out}\nStderr: {err}")
|
467
467
|
|
468
|
-
out = np.frombuffer(out, dtype=encoding_out)
|
468
|
+
out = np.frombuffer(out, dtype=encoding_out) # pyright: ignore [reportArgumentType, reportCallIssue]
|
469
469
|
if output_format["channels"] > 1:
|
470
470
|
out = out.reshape(
|
471
471
|
(output_format["channels"], int(len(out) / output_format["channels"])),
|
@@ -118,7 +118,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
|
|
118
118
|
# Apply IR and convert back to global sample rate
|
119
119
|
tfm = Transformer()
|
120
120
|
tfm.set_output_format(rate=SAMPLE_RATE)
|
121
|
-
tfm.fir(coefficients=temp.name)
|
121
|
+
tfm.fir(coefficients=temp.name) # pyright: ignore [reportArgumentType]
|
122
122
|
try:
|
123
123
|
audio_out = tfm.build_array(input_array=audio_out, sample_rate_in=ir.sample_rate)
|
124
124
|
except Exception as e:
|
@@ -1,19 +1,17 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AugmentationRule
|
2
|
-
from sonusai.mixture.datatypes import
|
3
|
-
from sonusai.mixture.datatypes import AugmentedTargets
|
2
|
+
from sonusai.mixture.datatypes import AugmentedTarget
|
4
3
|
from sonusai.mixture.datatypes import TargetFile
|
5
|
-
from sonusai.mixture.datatypes import TargetFiles
|
6
4
|
|
7
5
|
|
8
6
|
def balance_targets(
|
9
|
-
augmented_targets:
|
10
|
-
targets:
|
11
|
-
target_augmentations:
|
7
|
+
augmented_targets: list[AugmentedTarget],
|
8
|
+
targets: list[TargetFile],
|
9
|
+
target_augmentations: list[AugmentationRule],
|
12
10
|
class_balancing_augmentation: AugmentationRule,
|
13
11
|
num_classes: int,
|
14
12
|
num_ir: int,
|
15
13
|
mixups: list[int] | None = None,
|
16
|
-
) -> tuple[
|
14
|
+
) -> tuple[list[AugmentedTarget], list[AugmentationRule]]:
|
17
15
|
import math
|
18
16
|
|
19
17
|
from .augmentation import get_mixups
|
@@ -64,15 +62,15 @@ def balance_targets(
|
|
64
62
|
|
65
63
|
|
66
64
|
def _get_unused_balancing_augmentation(
|
67
|
-
augmented_targets:
|
68
|
-
targets:
|
69
|
-
target_augmentations:
|
65
|
+
augmented_targets: list[AugmentedTarget],
|
66
|
+
targets: list[TargetFile],
|
67
|
+
target_augmentations: list[AugmentationRule],
|
70
68
|
class_balancing_augmentation: AugmentationRule,
|
71
69
|
target_id: int,
|
72
70
|
mixup: int,
|
73
71
|
num_ir: int,
|
74
72
|
first_cba_id: int,
|
75
|
-
) -> tuple[int,
|
73
|
+
) -> tuple[int, list[AugmentationRule]]:
|
76
74
|
"""Get an unused balancing augmentation for a given target file index"""
|
77
75
|
from dataclasses import asdict
|
78
76
|
|
sonusai/mixture/targets.py
CHANGED
@@ -1,21 +1,20 @@
|
|
1
|
-
from sonusai.mixture.datatypes import
|
1
|
+
from sonusai.mixture.datatypes import AugmentationRule
|
2
2
|
from sonusai.mixture.datatypes import AugmentedTarget
|
3
|
-
from sonusai.mixture.datatypes import
|
4
|
-
from sonusai.mixture.datatypes import TargetFiles
|
3
|
+
from sonusai.mixture.datatypes import TargetFile
|
5
4
|
|
6
5
|
|
7
6
|
def get_augmented_targets(
|
8
|
-
target_files:
|
9
|
-
target_augmentations:
|
7
|
+
target_files: list[TargetFile],
|
8
|
+
target_augmentations: list[AugmentationRule],
|
10
9
|
mixups: list[int] | None = None,
|
11
|
-
) ->
|
10
|
+
) -> list[AugmentedTarget]:
|
12
11
|
from .augmentation import get_augmentation_indices_for_mixup
|
13
12
|
from .augmentation import get_mixups
|
14
13
|
|
15
14
|
if mixups is None:
|
16
15
|
mixups = get_mixups(target_augmentations)
|
17
16
|
|
18
|
-
augmented_targets:
|
17
|
+
augmented_targets: list[AugmentedTarget] = []
|
19
18
|
for mixup in mixups:
|
20
19
|
augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
|
21
20
|
for target_index in range(len(target_files)):
|
@@ -30,15 +29,17 @@ def get_augmented_targets(
|
|
30
29
|
return augmented_targets
|
31
30
|
|
32
31
|
|
33
|
-
def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets:
|
32
|
+
def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: list[TargetFile]) -> list[int]:
|
34
33
|
return targets[augmented_target.target_id].class_indices
|
35
34
|
|
36
35
|
|
37
|
-
def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations:
|
36
|
+
def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: list[AugmentationRule]) -> int:
|
38
37
|
return augmentations[augmented_target.target_augmentation_id].mixup
|
39
38
|
|
40
39
|
|
41
|
-
def get_target_ids_for_class_index(
|
40
|
+
def get_target_ids_for_class_index(
|
41
|
+
targets: list[TargetFile], class_index: int, allow_multiple: bool = False
|
42
|
+
) -> list[int]:
|
42
43
|
"""Get a list of target indices containing the given class index.
|
43
44
|
|
44
45
|
If allow_multiple is True, then include targets that contain multiple class indices.
|
@@ -55,9 +56,9 @@ def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow
|
|
55
56
|
|
56
57
|
|
57
58
|
def get_augmented_target_ids_for_class_index(
|
58
|
-
augmented_targets:
|
59
|
-
targets:
|
60
|
-
augmentations:
|
59
|
+
augmented_targets: list[AugmentedTarget],
|
60
|
+
targets: list[TargetFile],
|
61
|
+
augmentations: list[AugmentationRule],
|
61
62
|
class_index: int,
|
62
63
|
mixup: int,
|
63
64
|
allow_multiple: bool = False,
|
@@ -79,9 +80,9 @@ def get_augmented_target_ids_for_class_index(
|
|
79
80
|
|
80
81
|
|
81
82
|
def get_augmented_target_ids_by_class(
|
82
|
-
augmented_targets:
|
83
|
-
targets:
|
84
|
-
target_augmentations:
|
83
|
+
augmented_targets: list[AugmentedTarget],
|
84
|
+
targets: list[TargetFile],
|
85
|
+
target_augmentations: list[AugmentationRule],
|
85
86
|
mixup: int,
|
86
87
|
num_classes: int,
|
87
88
|
) -> list[list[int]]:
|
@@ -99,7 +100,9 @@ def get_augmented_target_ids_by_class(
|
|
99
100
|
return indices
|
100
101
|
|
101
102
|
|
102
|
-
def get_target_augmentations_for_mixup(
|
103
|
+
def get_target_augmentations_for_mixup(
|
104
|
+
target_augmentations: list[AugmentationRule], mixup: int
|
105
|
+
) -> list[AugmentationRule]:
|
103
106
|
"""Get target augmentations for a given mixup value
|
104
107
|
|
105
108
|
:param target_augmentations: List of target augmentation rules
|
@@ -110,9 +113,9 @@ def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules,
|
|
110
113
|
|
111
114
|
|
112
115
|
def get_augmented_target_ids_for_mixup(
|
113
|
-
augmented_targets:
|
114
|
-
targets:
|
115
|
-
target_augmentations:
|
116
|
+
augmented_targets: list[AugmentedTarget],
|
117
|
+
targets: list[TargetFile],
|
118
|
+
target_augmentations: list[AugmentationRule],
|
116
119
|
mixup: int,
|
117
120
|
num_classes: int,
|
118
121
|
) -> list[list[int]]:
|
sonusai/mixture/truth.py
CHANGED
@@ -1,39 +1,26 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
def truth_function(
|
8
|
-
target_audio: AudioT,
|
9
|
-
noise_audio: AudioT,
|
10
|
-
mixture_audio: AudioT,
|
11
|
-
config: TruthConfig,
|
12
|
-
feature: str,
|
13
|
-
num_classes: int,
|
14
|
-
class_indices: list[int],
|
15
|
-
target_gain: float,
|
16
|
-
) -> Truth:
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
3
|
+
|
4
|
+
|
5
|
+
def truth_function(mixdb: MixtureDatabase, m_id: int) -> list[Truth]:
|
6
|
+
from sonusai.mixture import TruthDict
|
17
7
|
from sonusai.mixture import truth_functions
|
18
8
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
35
|
-
except Exception as e:
|
36
|
-
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
9
|
+
result: list[Truth] = []
|
10
|
+
for target_index in range(len(mixdb.mixture(m_id).targets)):
|
11
|
+
truth: TruthDict = {}
|
12
|
+
target_file = mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id)
|
13
|
+
for name, config in target_file.truth_configs.items():
|
14
|
+
try:
|
15
|
+
truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, target_index, config.config)
|
16
|
+
except AttributeError as e:
|
17
|
+
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
18
|
+
except Exception as e:
|
19
|
+
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
20
|
+
|
21
|
+
result.append(truth)
|
22
|
+
|
23
|
+
return result
|
37
24
|
|
38
25
|
|
39
26
|
def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
|
@@ -22,6 +22,12 @@ from .energy import snr_f_validate
|
|
22
22
|
from .file import file
|
23
23
|
from .file import file_parameters
|
24
24
|
from .file import file_validate
|
25
|
+
from .metadata import metadata
|
26
|
+
from .metadata import metadata_parameters
|
27
|
+
from .metadata import metadata_validate
|
28
|
+
from .metrics import metrics
|
29
|
+
from .metrics import metrics_parameters
|
30
|
+
from .metrics import metrics_validate
|
25
31
|
from .phoneme import phoneme
|
26
32
|
from .phoneme import phoneme_parameters
|
27
33
|
from .phoneme import phoneme_validate
|
@@ -1,22 +1,32 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
4
3
|
|
5
4
|
|
6
|
-
def _core(
|
5
|
+
def _core(mixdb: MixtureDatabase, m_id: int, target_index: int, parameters: int, polar: bool) -> Truth:
|
7
6
|
import numpy as np
|
7
|
+
import torch
|
8
|
+
from pyaaware import ForwardTransform
|
9
|
+
from pyaaware import feature_forward_transform_config
|
10
|
+
from pyaaware import feature_inverse_transform_config
|
8
11
|
|
9
|
-
|
10
|
-
|
12
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
13
|
+
t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
14
|
+
n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
11
15
|
|
12
|
-
frames =
|
13
|
-
|
16
|
+
frames = t_ft.frames(target_audio)
|
17
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
18
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
19
|
+
|
20
|
+
noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))
|
21
|
+
|
22
|
+
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
23
|
+
|
24
|
+
frames = len(target_audio) // frame_size
|
25
|
+
truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
|
14
26
|
for frame in range(frames):
|
15
|
-
offset = frame *
|
16
|
-
target_f =
|
17
|
-
|
18
|
-
)
|
19
|
-
noise_f = config.noise_fft.execute(data.noise_audio[offset : offset + config.frame_size]).astype(np.complex64)
|
27
|
+
offset = frame * frame_size
|
28
|
+
target_f = t_ft.execute(target_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
29
|
+
noise_f = n_ft.execute(noise_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
20
30
|
mixture_f = target_f + noise_f
|
21
31
|
|
22
32
|
crm_data = np.empty(target_f.shape, dtype=np.complex64)
|
@@ -31,8 +41,8 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) ->
|
|
31
41
|
else:
|
32
42
|
crm_data[it.multi_index] = num / den
|
33
43
|
|
34
|
-
truth[frame, :
|
35
|
-
truth[frame,
|
44
|
+
truth[frame, : t_ft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
|
45
|
+
truth[frame, t_ft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
|
36
46
|
|
37
47
|
return truth
|
38
48
|
|
@@ -41,11 +51,14 @@ def crm_validate(_config: dict) -> None:
|
|
41
51
|
pass
|
42
52
|
|
43
53
|
|
44
|
-
def crm_parameters(
|
45
|
-
|
54
|
+
def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
55
|
+
from pyaaware import ForwardTransform
|
56
|
+
from pyaaware import feature_forward_transform_config
|
46
57
|
|
58
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
47
59
|
|
48
|
-
|
60
|
+
|
61
|
+
def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
49
62
|
"""Complex ratio mask truth generation function
|
50
63
|
|
51
64
|
Calculates the true complex ratio mask (CRM) truth which is a complex number
|
@@ -55,25 +68,27 @@ def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
55
68
|
|
56
69
|
Output shape: [:, 2 * bins]
|
57
70
|
"""
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
return _core(data=data, config=config, polar=False)
|
71
|
+
return _core(
|
72
|
+
mixdb=mixdb,
|
73
|
+
m_id=m_id,
|
74
|
+
target_index=target_index,
|
75
|
+
parameters=crm_parameters(mixdb.feature, mixdb.num_classes, _config),
|
76
|
+
polar=False,
|
77
|
+
)
|
66
78
|
|
67
79
|
|
68
80
|
def crmp_validate(_config: dict) -> None:
|
69
81
|
pass
|
70
82
|
|
71
83
|
|
72
|
-
def crmp_parameters(
|
73
|
-
|
84
|
+
def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
85
|
+
from pyaaware import ForwardTransform
|
86
|
+
from pyaaware import feature_forward_transform_config
|
74
87
|
|
88
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
75
89
|
|
76
|
-
|
90
|
+
|
91
|
+
def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
77
92
|
"""Complex ratio mask polar truth generation function
|
78
93
|
|
79
94
|
Same as the crm function except the results are magnitude and phase
|
@@ -81,11 +96,10 @@ def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
81
96
|
|
82
97
|
Output shape: [:, bins]
|
83
98
|
"""
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
return _core(data=data, config=config, polar=True)
|
99
|
+
return _core(
|
100
|
+
mixdb=mixdb,
|
101
|
+
m_id=m_id,
|
102
|
+
target_index=target_index,
|
103
|
+
parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
|
104
|
+
polar=True,
|
105
|
+
)
|
@@ -1,20 +1,44 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from sonusai.mixture
|
4
|
-
from sonusai.mixture
|
5
|
-
from sonusai.
|
3
|
+
from sonusai.mixture import MixtureDatabase
|
4
|
+
from sonusai.mixture import Truth
|
5
|
+
from sonusai.utils import load_object
|
6
|
+
|
7
|
+
|
8
|
+
def _core(
|
9
|
+
mixdb: MixtureDatabase,
|
10
|
+
m_id: int,
|
11
|
+
target_index: int,
|
12
|
+
config: dict,
|
13
|
+
parameters: int,
|
14
|
+
mapped: bool,
|
15
|
+
snr: bool,
|
16
|
+
) -> Truth:
|
17
|
+
from os.path import join
|
6
18
|
|
19
|
+
import torch
|
20
|
+
from pyaaware import ForwardTransform
|
21
|
+
from pyaaware import feature_forward_transform_config
|
7
22
|
|
8
|
-
def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, snr: bool) -> Truth:
|
9
23
|
from sonusai.utils import compute_energy_f
|
10
24
|
|
11
|
-
|
25
|
+
target_audio = mixdb.mixture_targets(m_id)[target_index]
|
26
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
27
|
+
|
28
|
+
frames = ft.frames(torch.from_numpy(target_audio))
|
29
|
+
|
30
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
31
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
32
|
+
|
33
|
+
noise_audio = mixdb.mixture_noise(m_id)
|
34
|
+
|
35
|
+
target_energy = compute_energy_f(time_domain=target_audio, transform=ft)
|
12
36
|
noise_energy = None
|
13
37
|
if snr:
|
14
|
-
noise_energy = compute_energy_f(time_domain=
|
38
|
+
noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)
|
15
39
|
|
16
40
|
frames = len(target_energy)
|
17
|
-
truth = np.empty((frames,
|
41
|
+
truth = np.empty((frames, ft.bins), dtype=np.float32)
|
18
42
|
for frame in range(frames):
|
19
43
|
tmp = target_energy[frame]
|
20
44
|
|
@@ -26,7 +50,9 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, sn
|
|
26
50
|
tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
27
51
|
|
28
52
|
if mapped:
|
29
|
-
|
53
|
+
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
|
54
|
+
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
|
55
|
+
tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
|
30
56
|
|
31
57
|
truth[frame] = tmp
|
32
58
|
|
@@ -52,11 +78,14 @@ def energy_f_validate(_config: dict) -> None:
|
|
52
78
|
pass
|
53
79
|
|
54
80
|
|
55
|
-
def energy_f_parameters(
|
56
|
-
|
81
|
+
def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
82
|
+
from pyaaware import ForwardTransform
|
83
|
+
from pyaaware import feature_forward_transform_config
|
84
|
+
|
85
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
57
86
|
|
58
87
|
|
59
|
-
def energy_f(
|
88
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
60
89
|
"""Frequency domain energy truth generation function
|
61
90
|
|
62
91
|
Calculates the true energy per bin:
|
@@ -67,23 +96,29 @@ def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
67
96
|
|
68
97
|
Output shape: [:, bins]
|
69
98
|
"""
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
99
|
+
return _core(
|
100
|
+
mixdb=mixdb,
|
101
|
+
m_id=m_id,
|
102
|
+
target_index=target_index,
|
103
|
+
config=config,
|
104
|
+
parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
105
|
+
mapped=False,
|
106
|
+
snr=False,
|
107
|
+
)
|
76
108
|
|
77
109
|
|
78
110
|
def snr_f_validate(_config: dict) -> None:
|
79
111
|
pass
|
80
112
|
|
81
113
|
|
82
|
-
def snr_f_parameters(
|
83
|
-
|
114
|
+
def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
115
|
+
from pyaaware import ForwardTransform
|
116
|
+
from pyaaware import feature_forward_transform_config
|
84
117
|
|
118
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
85
119
|
|
86
|
-
|
120
|
+
|
121
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
87
122
|
"""Frequency domain SNR truth function documentation
|
88
123
|
|
89
124
|
Calculates the true SNR per bin:
|
@@ -94,54 +129,58 @@ def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
94
129
|
|
95
130
|
Output shape: [:, bins]
|
96
131
|
"""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
132
|
+
return _core(
|
133
|
+
mixdb=mixdb,
|
134
|
+
m_id=m_id,
|
135
|
+
target_index=target_index,
|
136
|
+
config=config,
|
137
|
+
parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
138
|
+
mapped=False,
|
139
|
+
snr=True,
|
140
|
+
)
|
141
|
+
|
142
|
+
|
143
|
+
def mapped_snr_f_validate(config: dict) -> None:
|
144
|
+
if len(config) == 0:
|
107
145
|
raise AttributeError("mapped_snr_f truth function is missing config")
|
108
146
|
|
109
147
|
for parameter in ("snr_db_mean", "snr_db_std"):
|
110
|
-
if parameter not in config
|
148
|
+
if parameter not in config:
|
111
149
|
raise AttributeError(f"mapped_snr_f truth function is missing required '{parameter}'")
|
112
150
|
|
113
|
-
if len(config.config[parameter]) != config.target_fft.bins:
|
114
|
-
raise ValueError(
|
115
|
-
f"mapped_snr_f truth function '{parameter}' does not have {config.target_fft.bins} elements"
|
116
|
-
)
|
117
151
|
|
152
|
+
def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
153
|
+
from pyaaware import ForwardTransform
|
154
|
+
from pyaaware import feature_forward_transform_config
|
118
155
|
|
119
|
-
|
120
|
-
return config.target_fft.bins
|
156
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
121
157
|
|
122
158
|
|
123
|
-
def mapped_snr_f(
|
159
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
124
160
|
"""Frequency domain mapped SNR truth function documentation
|
125
161
|
|
126
162
|
Output shape: [:, bins]
|
127
163
|
"""
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
164
|
+
return _core(
|
165
|
+
mixdb=mixdb,
|
166
|
+
m_id=m_id,
|
167
|
+
target_index=target_index,
|
168
|
+
config=config,
|
169
|
+
parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
170
|
+
mapped=True,
|
171
|
+
snr=True,
|
172
|
+
)
|
134
173
|
|
135
174
|
|
136
175
|
def energy_t_validate(_config: dict) -> None:
|
137
176
|
pass
|
138
177
|
|
139
178
|
|
140
|
-
def energy_t_parameters(_config:
|
179
|
+
def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
|
141
180
|
return 1
|
142
181
|
|
143
182
|
|
144
|
-
def energy_t(
|
183
|
+
def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
145
184
|
"""Time domain energy truth function documentation
|
146
185
|
|
147
186
|
Calculates the true time domain energy of each frame:
|
@@ -164,10 +203,16 @@ def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
164
203
|
transform config.
|
165
204
|
"""
|
166
205
|
import torch
|
206
|
+
from pyaaware import ForwardTransform
|
207
|
+
from pyaaware import feature_forward_transform_config
|
208
|
+
|
209
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
210
|
+
|
211
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
167
212
|
|
168
|
-
frames =
|
169
|
-
parameters =
|
170
|
-
if
|
213
|
+
frames = ft.frames(target_audio)
|
214
|
+
parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, _config)
|
215
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
171
216
|
return np.zeros((frames, parameters), dtype=np.float32)
|
172
217
|
|
173
|
-
return
|
218
|
+
return ft.execute_all(target_audio)[1].numpy()
|