sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +56 -64
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +161 -204
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
@@ -32,7 +32,7 @@ def _raw_read(name: str | Path) -> tuple[AudioT, int]:
|
|
32
32
|
else:
|
33
33
|
raise OSError(f"Error reading {name}: {e}") from e
|
34
34
|
|
35
|
-
return np.squeeze(raw[:, 0]), sample_rate
|
35
|
+
return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
|
36
36
|
|
37
37
|
|
38
38
|
def get_sample_rate(name: str | Path) -> int:
|
sonusai/mixture/sox_audio.py
CHANGED
@@ -207,7 +207,7 @@ class Transformer(SoxTransformer):
|
|
207
207
|
|
208
208
|
return self
|
209
209
|
|
210
|
-
def build(
|
210
|
+
def build( # pyright: ignore [reportIncompatibleMethodOverride]
|
211
211
|
self,
|
212
212
|
input_filepath: str | Path | None = None,
|
213
213
|
output_filepath: str | Path | None = None,
|
@@ -320,11 +320,11 @@ class Transformer(SoxTransformer):
|
|
320
320
|
logger.info("Created %s with effects: %s", output_filepath, " ".join(self.effects_log))
|
321
321
|
|
322
322
|
if return_output:
|
323
|
-
return status, out, err
|
323
|
+
return status, out, err # pyright: ignore [reportReturnType]
|
324
324
|
|
325
325
|
return True, None, None
|
326
326
|
|
327
|
-
def build_array(
|
327
|
+
def build_array( # pyright: ignore [reportIncompatibleMethodOverride]
|
328
328
|
self,
|
329
329
|
input_filepath: str | Path | None = None,
|
330
330
|
input_array: np.ndarray | None = None,
|
@@ -465,7 +465,7 @@ class Transformer(SoxTransformer):
|
|
465
465
|
if status != 0:
|
466
466
|
raise SoxError(f"Stdout: {out}\nStderr: {err}")
|
467
467
|
|
468
|
-
out = np.frombuffer(out, dtype=encoding_out)
|
468
|
+
out = np.frombuffer(out, dtype=encoding_out) # pyright: ignore [reportArgumentType, reportCallIssue]
|
469
469
|
if output_format["channels"] > 1:
|
470
470
|
out = out.reshape(
|
471
471
|
(output_format["channels"], int(len(out) / output_format["channels"])),
|
@@ -118,7 +118,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
|
|
118
118
|
# Apply IR and convert back to global sample rate
|
119
119
|
tfm = Transformer()
|
120
120
|
tfm.set_output_format(rate=SAMPLE_RATE)
|
121
|
-
tfm.fir(coefficients=temp.name)
|
121
|
+
tfm.fir(coefficients=temp.name) # pyright: ignore [reportArgumentType]
|
122
122
|
try:
|
123
123
|
audio_out = tfm.build_array(input_array=audio_out, sample_rate_in=ir.sample_rate)
|
124
124
|
except Exception as e:
|
@@ -1,19 +1,17 @@
|
|
1
1
|
from sonusai.mixture.datatypes import AugmentationRule
|
2
|
-
from sonusai.mixture.datatypes import
|
3
|
-
from sonusai.mixture.datatypes import AugmentedTargets
|
2
|
+
from sonusai.mixture.datatypes import AugmentedTarget
|
4
3
|
from sonusai.mixture.datatypes import TargetFile
|
5
|
-
from sonusai.mixture.datatypes import TargetFiles
|
6
4
|
|
7
5
|
|
8
6
|
def balance_targets(
|
9
|
-
augmented_targets:
|
10
|
-
targets:
|
11
|
-
target_augmentations:
|
7
|
+
augmented_targets: list[AugmentedTarget],
|
8
|
+
targets: list[TargetFile],
|
9
|
+
target_augmentations: list[AugmentationRule],
|
12
10
|
class_balancing_augmentation: AugmentationRule,
|
13
11
|
num_classes: int,
|
14
12
|
num_ir: int,
|
15
13
|
mixups: list[int] | None = None,
|
16
|
-
) -> tuple[
|
14
|
+
) -> tuple[list[AugmentedTarget], list[AugmentationRule]]:
|
17
15
|
import math
|
18
16
|
|
19
17
|
from .augmentation import get_mixups
|
@@ -64,15 +62,15 @@ def balance_targets(
|
|
64
62
|
|
65
63
|
|
66
64
|
def _get_unused_balancing_augmentation(
|
67
|
-
augmented_targets:
|
68
|
-
targets:
|
69
|
-
target_augmentations:
|
65
|
+
augmented_targets: list[AugmentedTarget],
|
66
|
+
targets: list[TargetFile],
|
67
|
+
target_augmentations: list[AugmentationRule],
|
70
68
|
class_balancing_augmentation: AugmentationRule,
|
71
69
|
target_id: int,
|
72
70
|
mixup: int,
|
73
71
|
num_ir: int,
|
74
72
|
first_cba_id: int,
|
75
|
-
) -> tuple[int,
|
73
|
+
) -> tuple[int, list[AugmentationRule]]:
|
76
74
|
"""Get an unused balancing augmentation for a given target file index"""
|
77
75
|
from dataclasses import asdict
|
78
76
|
|
sonusai/mixture/targets.py
CHANGED
@@ -1,21 +1,20 @@
|
|
1
|
-
from sonusai.mixture.datatypes import
|
1
|
+
from sonusai.mixture.datatypes import AugmentationRule
|
2
2
|
from sonusai.mixture.datatypes import AugmentedTarget
|
3
|
-
from sonusai.mixture.datatypes import
|
4
|
-
from sonusai.mixture.datatypes import TargetFiles
|
3
|
+
from sonusai.mixture.datatypes import TargetFile
|
5
4
|
|
6
5
|
|
7
6
|
def get_augmented_targets(
|
8
|
-
target_files:
|
9
|
-
target_augmentations:
|
7
|
+
target_files: list[TargetFile],
|
8
|
+
target_augmentations: list[AugmentationRule],
|
10
9
|
mixups: list[int] | None = None,
|
11
|
-
) ->
|
10
|
+
) -> list[AugmentedTarget]:
|
12
11
|
from .augmentation import get_augmentation_indices_for_mixup
|
13
12
|
from .augmentation import get_mixups
|
14
13
|
|
15
14
|
if mixups is None:
|
16
15
|
mixups = get_mixups(target_augmentations)
|
17
16
|
|
18
|
-
augmented_targets:
|
17
|
+
augmented_targets: list[AugmentedTarget] = []
|
19
18
|
for mixup in mixups:
|
20
19
|
augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
|
21
20
|
for target_index in range(len(target_files)):
|
@@ -30,15 +29,17 @@ def get_augmented_targets(
|
|
30
29
|
return augmented_targets
|
31
30
|
|
32
31
|
|
33
|
-
def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets:
|
32
|
+
def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: list[TargetFile]) -> list[int]:
|
34
33
|
return targets[augmented_target.target_id].class_indices
|
35
34
|
|
36
35
|
|
37
|
-
def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations:
|
36
|
+
def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: list[AugmentationRule]) -> int:
|
38
37
|
return augmentations[augmented_target.target_augmentation_id].mixup
|
39
38
|
|
40
39
|
|
41
|
-
def get_target_ids_for_class_index(
|
40
|
+
def get_target_ids_for_class_index(
|
41
|
+
targets: list[TargetFile], class_index: int, allow_multiple: bool = False
|
42
|
+
) -> list[int]:
|
42
43
|
"""Get a list of target indices containing the given class index.
|
43
44
|
|
44
45
|
If allow_multiple is True, then include targets that contain multiple class indices.
|
@@ -55,9 +56,9 @@ def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow
|
|
55
56
|
|
56
57
|
|
57
58
|
def get_augmented_target_ids_for_class_index(
|
58
|
-
augmented_targets:
|
59
|
-
targets:
|
60
|
-
augmentations:
|
59
|
+
augmented_targets: list[AugmentedTarget],
|
60
|
+
targets: list[TargetFile],
|
61
|
+
augmentations: list[AugmentationRule],
|
61
62
|
class_index: int,
|
62
63
|
mixup: int,
|
63
64
|
allow_multiple: bool = False,
|
@@ -79,9 +80,9 @@ def get_augmented_target_ids_for_class_index(
|
|
79
80
|
|
80
81
|
|
81
82
|
def get_augmented_target_ids_by_class(
|
82
|
-
augmented_targets:
|
83
|
-
targets:
|
84
|
-
target_augmentations:
|
83
|
+
augmented_targets: list[AugmentedTarget],
|
84
|
+
targets: list[TargetFile],
|
85
|
+
target_augmentations: list[AugmentationRule],
|
85
86
|
mixup: int,
|
86
87
|
num_classes: int,
|
87
88
|
) -> list[list[int]]:
|
@@ -99,7 +100,9 @@ def get_augmented_target_ids_by_class(
|
|
99
100
|
return indices
|
100
101
|
|
101
102
|
|
102
|
-
def get_target_augmentations_for_mixup(
|
103
|
+
def get_target_augmentations_for_mixup(
|
104
|
+
target_augmentations: list[AugmentationRule], mixup: int
|
105
|
+
) -> list[AugmentationRule]:
|
103
106
|
"""Get target augmentations for a given mixup value
|
104
107
|
|
105
108
|
:param target_augmentations: List of target augmentation rules
|
@@ -110,9 +113,9 @@ def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules,
|
|
110
113
|
|
111
114
|
|
112
115
|
def get_augmented_target_ids_for_mixup(
|
113
|
-
augmented_targets:
|
114
|
-
targets:
|
115
|
-
target_augmentations:
|
116
|
+
augmented_targets: list[AugmentedTarget],
|
117
|
+
targets: list[TargetFile],
|
118
|
+
target_augmentations: list[AugmentationRule],
|
116
119
|
mixup: int,
|
117
120
|
num_classes: int,
|
118
121
|
) -> list[list[int]]:
|
@@ -4,10 +4,16 @@ from sonusai.mixture.datatypes import AudioT
|
|
4
4
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
5
5
|
|
6
6
|
|
7
|
-
def read_impulse_response(
|
7
|
+
def read_impulse_response(
|
8
|
+
name: str | Path,
|
9
|
+
delay_compensation: bool = True,
|
10
|
+
normalize: bool = True,
|
11
|
+
) -> ImpulseResponseData:
|
8
12
|
"""Read impulse response data using torchaudio
|
9
13
|
|
10
14
|
:param name: File name
|
15
|
+
:param delay_compensation: Apply delay compensation
|
16
|
+
:param normalize: Apply normalization
|
11
17
|
:return: ImpulseResponseData object
|
12
18
|
"""
|
13
19
|
import numpy as np
|
@@ -28,14 +34,19 @@ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
|
|
28
34
|
raise OSError(f"Error reading {name}: {e}") from e
|
29
35
|
|
30
36
|
raw = torch.squeeze(raw[0, :])
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
# raw = raw / torch.linalg.vector_norm(raw)
|
37
|
+
|
38
|
+
if delay_compensation:
|
39
|
+
offset = torch.argmax(raw)
|
40
|
+
raw = raw[offset:]
|
36
41
|
|
37
42
|
data = np.array(raw).astype(np.float32)
|
38
|
-
|
43
|
+
|
44
|
+
if normalize:
|
45
|
+
# Inexplicably,
|
46
|
+
# data = data / torch.linalg.vector_norm(data)
|
47
|
+
# causes multiprocessing contexts to hang.
|
48
|
+
# Use np.linalg.norm() instead.
|
49
|
+
data = data / np.linalg.norm(data)
|
39
50
|
|
40
51
|
return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
|
41
52
|
|
@@ -20,10 +20,9 @@ def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length:
|
|
20
20
|
|
21
21
|
effects: list[list[str]] = []
|
22
22
|
|
23
|
-
# TODO
|
24
|
-
#
|
25
|
-
#
|
26
|
-
# or hard-coded into the script?)
|
23
|
+
# TODO: Always normalize and remove normalize from list of available augmentations
|
24
|
+
# Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
|
25
|
+
# TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
|
27
26
|
if augmentation.normalize is not None:
|
28
27
|
effects.append(["norm", str(augmentation.normalize)])
|
29
28
|
|
sonusai/mixture/truth.py
CHANGED
@@ -1,39 +1,26 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
def truth_function(
|
8
|
-
target_audio: AudioT,
|
9
|
-
noise_audio: AudioT,
|
10
|
-
mixture_audio: AudioT,
|
11
|
-
config: TruthConfig,
|
12
|
-
feature: str,
|
13
|
-
num_classes: int,
|
14
|
-
class_indices: list[int],
|
15
|
-
target_gain: float,
|
16
|
-
) -> Truth:
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
3
|
+
|
4
|
+
|
5
|
+
def truth_function(mixdb: MixtureDatabase, m_id: int) -> list[Truth]:
|
6
|
+
from sonusai.mixture import TruthDict
|
17
7
|
from sonusai.mixture import truth_functions
|
18
8
|
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
35
|
-
except Exception as e:
|
36
|
-
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
9
|
+
result: list[Truth] = []
|
10
|
+
for target_index in range(len(mixdb.mixture(m_id).targets)):
|
11
|
+
truth: TruthDict = {}
|
12
|
+
target_file = mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id)
|
13
|
+
for name, config in target_file.truth_configs.items():
|
14
|
+
try:
|
15
|
+
truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, target_index, config.config)
|
16
|
+
except AttributeError as e:
|
17
|
+
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
18
|
+
except Exception as e:
|
19
|
+
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
20
|
+
|
21
|
+
result.append(truth)
|
22
|
+
|
23
|
+
return result
|
37
24
|
|
38
25
|
|
39
26
|
def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
|
@@ -22,6 +22,12 @@ from .energy import snr_f_validate
|
|
22
22
|
from .file import file
|
23
23
|
from .file import file_parameters
|
24
24
|
from .file import file_validate
|
25
|
+
from .metadata import metadata
|
26
|
+
from .metadata import metadata_parameters
|
27
|
+
from .metadata import metadata_validate
|
28
|
+
from .metrics import metrics
|
29
|
+
from .metrics import metrics_parameters
|
30
|
+
from .metrics import metrics_validate
|
25
31
|
from .phoneme import phoneme
|
26
32
|
from .phoneme import phoneme_parameters
|
27
33
|
from .phoneme import phoneme_validate
|
@@ -1,22 +1,32 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
4
3
|
|
5
4
|
|
6
|
-
def _core(
|
5
|
+
def _core(mixdb: MixtureDatabase, m_id: int, target_index: int, parameters: int, polar: bool) -> Truth:
|
7
6
|
import numpy as np
|
7
|
+
import torch
|
8
|
+
from pyaaware import ForwardTransform
|
9
|
+
from pyaaware import feature_forward_transform_config
|
10
|
+
from pyaaware import feature_inverse_transform_config
|
8
11
|
|
9
|
-
|
10
|
-
|
12
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
13
|
+
t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
14
|
+
n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
11
15
|
|
12
|
-
frames =
|
13
|
-
|
16
|
+
frames = t_ft.frames(target_audio)
|
17
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
18
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
19
|
+
|
20
|
+
noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))
|
21
|
+
|
22
|
+
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
23
|
+
|
24
|
+
frames = len(target_audio) // frame_size
|
25
|
+
truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
|
14
26
|
for frame in range(frames):
|
15
|
-
offset = frame *
|
16
|
-
target_f =
|
17
|
-
|
18
|
-
)
|
19
|
-
noise_f = config.noise_fft.execute(data.noise_audio[offset : offset + config.frame_size]).astype(np.complex64)
|
27
|
+
offset = frame * frame_size
|
28
|
+
target_f = t_ft.execute(target_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
29
|
+
noise_f = n_ft.execute(noise_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
20
30
|
mixture_f = target_f + noise_f
|
21
31
|
|
22
32
|
crm_data = np.empty(target_f.shape, dtype=np.complex64)
|
@@ -31,8 +41,8 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) ->
|
|
31
41
|
else:
|
32
42
|
crm_data[it.multi_index] = num / den
|
33
43
|
|
34
|
-
truth[frame, :
|
35
|
-
truth[frame,
|
44
|
+
truth[frame, : t_ft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
|
45
|
+
truth[frame, t_ft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
|
36
46
|
|
37
47
|
return truth
|
38
48
|
|
@@ -41,11 +51,14 @@ def crm_validate(_config: dict) -> None:
|
|
41
51
|
pass
|
42
52
|
|
43
53
|
|
44
|
-
def crm_parameters(
|
45
|
-
|
54
|
+
def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
55
|
+
from pyaaware import ForwardTransform
|
56
|
+
from pyaaware import feature_forward_transform_config
|
46
57
|
|
58
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
47
59
|
|
48
|
-
|
60
|
+
|
61
|
+
def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
49
62
|
"""Complex ratio mask truth generation function
|
50
63
|
|
51
64
|
Calculates the true complex ratio mask (CRM) truth which is a complex number
|
@@ -55,25 +68,27 @@ def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
55
68
|
|
56
69
|
Output shape: [:, 2 * bins]
|
57
70
|
"""
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
return _core(data=data, config=config, polar=False)
|
71
|
+
return _core(
|
72
|
+
mixdb=mixdb,
|
73
|
+
m_id=m_id,
|
74
|
+
target_index=target_index,
|
75
|
+
parameters=crm_parameters(mixdb.feature, mixdb.num_classes, _config),
|
76
|
+
polar=False,
|
77
|
+
)
|
66
78
|
|
67
79
|
|
68
80
|
def crmp_validate(_config: dict) -> None:
|
69
81
|
pass
|
70
82
|
|
71
83
|
|
72
|
-
def crmp_parameters(
|
73
|
-
|
84
|
+
def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
85
|
+
from pyaaware import ForwardTransform
|
86
|
+
from pyaaware import feature_forward_transform_config
|
74
87
|
|
88
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
75
89
|
|
76
|
-
|
90
|
+
|
91
|
+
def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
77
92
|
"""Complex ratio mask polar truth generation function
|
78
93
|
|
79
94
|
Same as the crm function except the results are magnitude and phase
|
@@ -81,11 +96,10 @@ def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
81
96
|
|
82
97
|
Output shape: [:, bins]
|
83
98
|
"""
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
return _core(data=data, config=config, polar=True)
|
99
|
+
return _core(
|
100
|
+
mixdb=mixdb,
|
101
|
+
m_id=m_id,
|
102
|
+
target_index=target_index,
|
103
|
+
parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
|
104
|
+
polar=True,
|
105
|
+
)
|