sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -1
- sonusai/aawscd_probwrite.py +1 -1
- sonusai/calc_metric_spenh.py +1 -1
- sonusai/genft.py +29 -14
- sonusai/genmetrics.py +60 -42
- sonusai/genmix.py +41 -29
- sonusai/genmixdb.py +56 -64
- sonusai/metrics/calc_class_weights.py +1 -3
- sonusai/metrics/calc_optimal_thresholds.py +2 -2
- sonusai/metrics/calc_phase_distance.py +1 -1
- sonusai/metrics/calc_speech.py +6 -6
- sonusai/metrics/class_summary.py +6 -15
- sonusai/metrics/confusion_matrix_summary.py +11 -27
- sonusai/metrics/one_hot.py +3 -3
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/mixture/__init__.py +2 -17
- sonusai/mixture/augmentation.py +5 -6
- sonusai/mixture/class_count.py +1 -1
- sonusai/mixture/config.py +36 -46
- sonusai/mixture/data_io.py +30 -1
- sonusai/mixture/datatypes.py +29 -40
- sonusai/mixture/db_datatypes.py +1 -1
- sonusai/mixture/feature.py +3 -23
- sonusai/mixture/generation.py +161 -204
- sonusai/mixture/helpers.py +29 -187
- sonusai/mixture/mixdb.py +386 -159
- sonusai/mixture/soundfile_audio.py +1 -1
- sonusai/mixture/sox_audio.py +4 -4
- sonusai/mixture/sox_augmentation.py +1 -1
- sonusai/mixture/target_class_balancing.py +9 -11
- sonusai/mixture/targets.py +23 -20
- sonusai/mixture/torchaudio_audio.py +18 -7
- sonusai/mixture/torchaudio_augmentation.py +3 -4
- sonusai/mixture/truth.py +21 -34
- sonusai/mixture/truth_functions/__init__.py +6 -0
- sonusai/mixture/truth_functions/crm.py +51 -37
- sonusai/mixture/truth_functions/energy.py +95 -50
- sonusai/mixture/truth_functions/file.py +12 -8
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +4 -5
- sonusai/mixture/truth_functions/sed.py +32 -23
- sonusai/mixture/truth_functions/target.py +62 -29
- sonusai/mkwav.py +20 -19
- sonusai/queries/queries.py +9 -15
- sonusai/speech/l2arctic.py +6 -2
- sonusai/summarize_metric_spenh.py +1 -1
- sonusai/utils/__init__.py +1 -0
- sonusai/utils/asr_functions/aaware_whisper.py +1 -1
- sonusai/utils/audio_devices.py +27 -18
- sonusai/utils/docstring.py +6 -3
- sonusai/utils/energy_f.py +5 -3
- sonusai/utils/human_readable_size.py +6 -6
- sonusai/utils/load_object.py +15 -0
- sonusai/utils/onnx_utils.py +2 -2
- sonusai/utils/print_mixture_details.py +3 -3
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
- sonusai/mixture/truth_functions/datatypes.py +0 -37
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
- {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,44 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from sonusai.mixture
|
4
|
-
from sonusai.mixture
|
5
|
-
from sonusai.
|
3
|
+
from sonusai.mixture import MixtureDatabase
|
4
|
+
from sonusai.mixture import Truth
|
5
|
+
from sonusai.utils import load_object
|
6
|
+
|
7
|
+
|
8
|
+
def _core(
|
9
|
+
mixdb: MixtureDatabase,
|
10
|
+
m_id: int,
|
11
|
+
target_index: int,
|
12
|
+
config: dict,
|
13
|
+
parameters: int,
|
14
|
+
mapped: bool,
|
15
|
+
snr: bool,
|
16
|
+
) -> Truth:
|
17
|
+
from os.path import join
|
6
18
|
|
19
|
+
import torch
|
20
|
+
from pyaaware import ForwardTransform
|
21
|
+
from pyaaware import feature_forward_transform_config
|
7
22
|
|
8
|
-
def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, snr: bool) -> Truth:
|
9
23
|
from sonusai.utils import compute_energy_f
|
10
24
|
|
11
|
-
|
25
|
+
target_audio = mixdb.mixture_targets(m_id)[target_index]
|
26
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
27
|
+
|
28
|
+
frames = ft.frames(torch.from_numpy(target_audio))
|
29
|
+
|
30
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
31
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
32
|
+
|
33
|
+
noise_audio = mixdb.mixture_noise(m_id)
|
34
|
+
|
35
|
+
target_energy = compute_energy_f(time_domain=target_audio, transform=ft)
|
12
36
|
noise_energy = None
|
13
37
|
if snr:
|
14
|
-
noise_energy = compute_energy_f(time_domain=
|
38
|
+
noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)
|
15
39
|
|
16
40
|
frames = len(target_energy)
|
17
|
-
truth = np.empty((frames,
|
41
|
+
truth = np.empty((frames, ft.bins), dtype=np.float32)
|
18
42
|
for frame in range(frames):
|
19
43
|
tmp = target_energy[frame]
|
20
44
|
|
@@ -26,7 +50,9 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, sn
|
|
26
50
|
tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
|
27
51
|
|
28
52
|
if mapped:
|
29
|
-
|
53
|
+
snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
|
54
|
+
snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
|
55
|
+
tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
|
30
56
|
|
31
57
|
truth[frame] = tmp
|
32
58
|
|
@@ -52,11 +78,14 @@ def energy_f_validate(_config: dict) -> None:
|
|
52
78
|
pass
|
53
79
|
|
54
80
|
|
55
|
-
def energy_f_parameters(
|
56
|
-
|
81
|
+
def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
82
|
+
from pyaaware import ForwardTransform
|
83
|
+
from pyaaware import feature_forward_transform_config
|
84
|
+
|
85
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
57
86
|
|
58
87
|
|
59
|
-
def energy_f(
|
88
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
60
89
|
"""Frequency domain energy truth generation function
|
61
90
|
|
62
91
|
Calculates the true energy per bin:
|
@@ -67,23 +96,29 @@ def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
67
96
|
|
68
97
|
Output shape: [:, bins]
|
69
98
|
"""
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
99
|
+
return _core(
|
100
|
+
mixdb=mixdb,
|
101
|
+
m_id=m_id,
|
102
|
+
target_index=target_index,
|
103
|
+
config=config,
|
104
|
+
parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
105
|
+
mapped=False,
|
106
|
+
snr=False,
|
107
|
+
)
|
76
108
|
|
77
109
|
|
78
110
|
def snr_f_validate(_config: dict) -> None:
|
79
111
|
pass
|
80
112
|
|
81
113
|
|
82
|
-
def snr_f_parameters(
|
83
|
-
|
114
|
+
def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
115
|
+
from pyaaware import ForwardTransform
|
116
|
+
from pyaaware import feature_forward_transform_config
|
84
117
|
|
118
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
85
119
|
|
86
|
-
|
120
|
+
|
121
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
87
122
|
"""Frequency domain SNR truth function documentation
|
88
123
|
|
89
124
|
Calculates the true SNR per bin:
|
@@ -94,54 +129,58 @@ def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
94
129
|
|
95
130
|
Output shape: [:, bins]
|
96
131
|
"""
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
|
105
|
-
|
106
|
-
|
132
|
+
return _core(
|
133
|
+
mixdb=mixdb,
|
134
|
+
m_id=m_id,
|
135
|
+
target_index=target_index,
|
136
|
+
config=config,
|
137
|
+
parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
138
|
+
mapped=False,
|
139
|
+
snr=True,
|
140
|
+
)
|
141
|
+
|
142
|
+
|
143
|
+
def mapped_snr_f_validate(config: dict) -> None:
|
144
|
+
if len(config) == 0:
|
107
145
|
raise AttributeError("mapped_snr_f truth function is missing config")
|
108
146
|
|
109
147
|
for parameter in ("snr_db_mean", "snr_db_std"):
|
110
|
-
if parameter not in config
|
148
|
+
if parameter not in config:
|
111
149
|
raise AttributeError(f"mapped_snr_f truth function is missing required '{parameter}'")
|
112
150
|
|
113
|
-
if len(config.config[parameter]) != config.target_fft.bins:
|
114
|
-
raise ValueError(
|
115
|
-
f"mapped_snr_f truth function '{parameter}' does not have {config.target_fft.bins} elements"
|
116
|
-
)
|
117
151
|
|
152
|
+
def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
153
|
+
from pyaaware import ForwardTransform
|
154
|
+
from pyaaware import feature_forward_transform_config
|
118
155
|
|
119
|
-
|
120
|
-
return config.target_fft.bins
|
156
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
121
157
|
|
122
158
|
|
123
|
-
def mapped_snr_f(
|
159
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
124
160
|
"""Frequency domain mapped SNR truth function documentation
|
125
161
|
|
126
162
|
Output shape: [:, bins]
|
127
163
|
"""
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
164
|
+
return _core(
|
165
|
+
mixdb=mixdb,
|
166
|
+
m_id=m_id,
|
167
|
+
target_index=target_index,
|
168
|
+
config=config,
|
169
|
+
parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
170
|
+
mapped=True,
|
171
|
+
snr=True,
|
172
|
+
)
|
134
173
|
|
135
174
|
|
136
175
|
def energy_t_validate(_config: dict) -> None:
|
137
176
|
pass
|
138
177
|
|
139
178
|
|
140
|
-
def energy_t_parameters(_config:
|
179
|
+
def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
|
141
180
|
return 1
|
142
181
|
|
143
182
|
|
144
|
-
def energy_t(
|
183
|
+
def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
145
184
|
"""Time domain energy truth function documentation
|
146
185
|
|
147
186
|
Calculates the true time domain energy of each frame:
|
@@ -164,10 +203,16 @@ def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
164
203
|
transform config.
|
165
204
|
"""
|
166
205
|
import torch
|
206
|
+
from pyaaware import ForwardTransform
|
207
|
+
from pyaaware import feature_forward_transform_config
|
208
|
+
|
209
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
210
|
+
|
211
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
167
212
|
|
168
|
-
frames =
|
169
|
-
parameters =
|
170
|
-
if
|
213
|
+
frames = ft.frames(target_audio)
|
214
|
+
parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, _config)
|
215
|
+
if mixdb.mixture(m_id).target_gain(target_index) == 0:
|
171
216
|
return np.zeros((frames, parameters), dtype=np.float32)
|
172
217
|
|
173
|
-
return
|
218
|
+
return ft.execute_all(target_audio)[1].numpy()
|
@@ -1,6 +1,5 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
4
3
|
|
5
4
|
|
6
5
|
def file_validate(config: dict) -> None:
|
@@ -17,28 +16,33 @@ def file_validate(config: dict) -> None:
|
|
17
16
|
raise ValueError("Truth file does not contain truth_f dataset")
|
18
17
|
|
19
18
|
|
20
|
-
def file_parameters(config:
|
19
|
+
def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
|
21
20
|
import h5py
|
22
21
|
import numpy as np
|
23
22
|
|
24
|
-
with h5py.File(config
|
23
|
+
with h5py.File(config["file"], "r") as f:
|
25
24
|
truth = np.array(f["truth_f"])
|
26
25
|
|
27
26
|
return truth.shape[-1]
|
28
27
|
|
29
28
|
|
30
|
-
def file(
|
29
|
+
def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
31
30
|
"""file truth function documentation"""
|
32
31
|
import h5py
|
33
32
|
import numpy as np
|
33
|
+
from pyaaware import feature_inverse_transform_config
|
34
|
+
|
35
|
+
target_audio = mixdb.mixture_targets(m_id)[target_index]
|
34
36
|
|
35
|
-
|
37
|
+
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
38
|
+
|
39
|
+
with h5py.File(config["file"], "r") as f:
|
36
40
|
truth = np.array(f["truth_f"])
|
37
41
|
|
38
42
|
if truth.ndim != 2:
|
39
43
|
raise ValueError("Truth file data is not 2 dimensions")
|
40
44
|
|
41
|
-
if truth.shape[0] != len(
|
45
|
+
if truth.shape[0] != len(target_audio) // frame_size:
|
42
46
|
raise ValueError("Truth file does not contain the right amount of frames")
|
43
47
|
|
44
48
|
return truth
|
@@ -0,0 +1,24 @@
|
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
3
|
+
|
4
|
+
|
5
|
+
def metadata_validate(config: dict) -> None:
|
6
|
+
if len(config) == 0:
|
7
|
+
raise AttributeError("metadata truth function is missing config")
|
8
|
+
|
9
|
+
parameters = ["tier"]
|
10
|
+
for parameter in parameters:
|
11
|
+
if parameter not in config:
|
12
|
+
raise AttributeError(f"metadata truth function is missing required '{parameter}'")
|
13
|
+
|
14
|
+
|
15
|
+
def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
|
16
|
+
return None
|
17
|
+
|
18
|
+
|
19
|
+
def metadata(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
20
|
+
"""Metadata truth generation function
|
21
|
+
|
22
|
+
Retrieves metadata from target.
|
23
|
+
"""
|
24
|
+
return mixdb.mixture_speech_metadata(m_id, config["tier"])[target_index]
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
3
|
+
|
4
|
+
|
5
|
+
def metrics_validate(config: dict) -> None:
|
6
|
+
if len(config) == 0:
|
7
|
+
raise AttributeError("metrics truth function is missing config")
|
8
|
+
|
9
|
+
parameters = ["metric"]
|
10
|
+
for parameter in parameters:
|
11
|
+
if parameter not in config:
|
12
|
+
raise AttributeError(f"metrics truth function is missing required '{parameter}'")
|
13
|
+
|
14
|
+
|
15
|
+
def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
|
16
|
+
return None
|
17
|
+
|
18
|
+
|
19
|
+
def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
20
|
+
"""Metadata truth generation function
|
21
|
+
|
22
|
+
Retrieves metrics from target.
|
23
|
+
"""
|
24
|
+
if not isinstance(config["metric"], list):
|
25
|
+
m = [config["metric"]]
|
26
|
+
else:
|
27
|
+
m = config["metric"]
|
28
|
+
return mixdb.mixture_metrics(m_id, m)[0][target_index]
|
@@ -1,17 +1,16 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
4
3
|
|
5
4
|
|
6
5
|
def phoneme_validate(_config: dict) -> None:
|
7
6
|
raise NotImplementedError("Truth function phoneme is not supported yet")
|
8
7
|
|
9
8
|
|
10
|
-
def phoneme_parameters(_config:
|
9
|
+
def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
|
11
10
|
raise NotImplementedError("Truth function phoneme is not supported yet")
|
12
11
|
|
13
12
|
|
14
|
-
def phoneme(
|
13
|
+
def phoneme(_mixdb: MixtureDatabase, _m_id: int, _target_index: int, _config: dict) -> Truth:
|
15
14
|
"""Read in .txt transcript and run a Python function to generate text grid data
|
16
15
|
(indicating which phonemes are active). Then generate truth based on this data and put
|
17
16
|
in the correct classes based on the index in the config.
|
@@ -1,12 +1,5 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
4
|
-
|
5
|
-
|
6
|
-
def _strictly_decreasing(list_to_check: list) -> bool:
|
7
|
-
from itertools import pairwise
|
8
|
-
|
9
|
-
return all(x > y for x, y in pairwise(list_to_check))
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
10
3
|
|
11
4
|
|
12
5
|
def sed_validate(config: dict) -> None:
|
@@ -23,11 +16,11 @@ def sed_validate(config: dict) -> None:
|
|
23
16
|
raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
|
24
17
|
|
25
18
|
|
26
|
-
def sed_parameters(
|
27
|
-
return
|
19
|
+
def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
|
20
|
+
return num_classes
|
28
21
|
|
29
22
|
|
30
|
-
def sed(
|
23
|
+
def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
|
31
24
|
"""Sound energy detection truth generation function
|
32
25
|
|
33
26
|
Calculates sound energy detection truth using simple 3 threshold
|
@@ -62,30 +55,46 @@ def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
62
55
|
import numpy as np
|
63
56
|
import torch
|
64
57
|
from pyaaware import SED
|
58
|
+
from pyaaware import ForwardTransform
|
59
|
+
from pyaaware import feature_forward_transform_config
|
60
|
+
from pyaaware import feature_inverse_transform_config
|
61
|
+
|
62
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
63
|
+
|
64
|
+
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
65
65
|
|
66
|
-
|
67
|
-
raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")
|
66
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
68
67
|
|
69
|
-
|
70
|
-
|
71
|
-
|
68
|
+
if len(target_audio) % frame_size != 0:
|
69
|
+
raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")
|
70
|
+
|
71
|
+
frames = ft.frames(target_audio)
|
72
|
+
parameters = sed_parameters(mixdb.feature, mixdb.num_classes, config)
|
73
|
+
target_gain = mixdb.mixture(m_id).target_gain(target_index)
|
74
|
+
if target_gain == 0:
|
72
75
|
return np.zeros((frames, parameters), dtype=np.float32)
|
73
76
|
|
74
77
|
# SED wants 1-based indices
|
75
78
|
s = SED(
|
76
|
-
thresholds=config
|
77
|
-
index=
|
78
|
-
frame_size=
|
79
|
-
num_classes=
|
79
|
+
thresholds=config["thresholds"],
|
80
|
+
index=mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id).class_indices,
|
81
|
+
frame_size=frame_size,
|
82
|
+
num_classes=mixdb.num_classes,
|
80
83
|
)
|
81
84
|
|
82
85
|
# Back out target gain
|
83
|
-
target_audio =
|
86
|
+
target_audio = target_audio / target_gain
|
84
87
|
|
85
88
|
# Compute energy
|
86
|
-
target_energy =
|
89
|
+
target_energy = ft.execute_all(target_audio)[1].numpy()
|
87
90
|
|
88
91
|
if frames != target_energy.shape[0]:
|
89
92
|
raise ValueError("Incorrect frames calculation in sed truth function")
|
90
93
|
|
91
94
|
return s.execute_all(target_energy)
|
95
|
+
|
96
|
+
|
97
|
+
def _strictly_decreasing(list_to_check: list) -> bool:
|
98
|
+
from itertools import pairwise
|
99
|
+
|
100
|
+
return all(x > y for x, y in pairwise(list_to_check))
|
@@ -1,21 +1,24 @@
|
|
1
|
-
from sonusai.mixture
|
2
|
-
from sonusai.mixture
|
3
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
4
|
-
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
1
|
+
from sonusai.mixture import MixtureDatabase
|
2
|
+
from sonusai.mixture import Truth
|
5
3
|
|
6
4
|
|
7
5
|
def target_f_validate(_config: dict) -> None:
|
8
6
|
pass
|
9
7
|
|
10
8
|
|
11
|
-
def target_f_parameters(
|
12
|
-
|
13
|
-
|
9
|
+
def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
10
|
+
from pyaaware import ForwardTransform
|
11
|
+
from pyaaware import feature_forward_transform_config
|
14
12
|
|
15
|
-
|
13
|
+
ft = ForwardTransform(**feature_forward_transform_config(feature))
|
16
14
|
|
15
|
+
if ft.ttype == "tdac-co":
|
16
|
+
return ft.bins
|
17
17
|
|
18
|
-
|
18
|
+
return ft.bins * 2
|
19
|
+
|
20
|
+
|
21
|
+
def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
19
22
|
"""Frequency domain target truth function
|
20
23
|
|
21
24
|
Calculates the true transform of the target using the STFT
|
@@ -26,23 +29,34 @@ def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
|
26
29
|
[:, bins] (target real only for tdac-co)
|
27
30
|
"""
|
28
31
|
import torch
|
32
|
+
from pyaaware import ForwardTransform
|
33
|
+
from pyaaware import feature_forward_transform_config
|
34
|
+
|
35
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
29
36
|
|
30
|
-
|
31
|
-
|
37
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
38
|
+
|
39
|
+
target_freq = ft.execute_all(target_audio)[0].numpy()
|
40
|
+
return _stack_real_imag(target_freq, ft.ttype)
|
32
41
|
|
33
42
|
|
34
43
|
def target_mixture_f_validate(_config: dict) -> None:
|
35
44
|
pass
|
36
45
|
|
37
46
|
|
38
|
-
def target_mixture_f_parameters(
|
39
|
-
|
40
|
-
|
47
|
+
def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
48
|
+
from pyaaware import ForwardTransform
|
49
|
+
from pyaaware import feature_forward_transform_config
|
50
|
+
|
51
|
+
ft = ForwardTransform(**feature_forward_transform_config(feature))
|
52
|
+
|
53
|
+
if ft.ttype == "tdac-co":
|
54
|
+
return ft.bins * 2
|
41
55
|
|
42
|
-
return
|
56
|
+
return ft.bins * 4
|
43
57
|
|
44
58
|
|
45
|
-
def target_mixture_f(
|
59
|
+
def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
46
60
|
"""Frequency domain target and mixture truth function
|
47
61
|
|
48
62
|
Calculates the true transform of the target and the mixture
|
@@ -55,14 +69,21 @@ def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Tr
|
|
55
69
|
"""
|
56
70
|
import numpy as np
|
57
71
|
import torch
|
72
|
+
from pyaaware import ForwardTransform
|
73
|
+
from pyaaware import feature_forward_transform_config
|
58
74
|
|
59
|
-
|
60
|
-
|
75
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
76
|
+
|
77
|
+
target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
|
78
|
+
mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))
|
79
|
+
|
80
|
+
target_freq = ft.execute_all(torch.from_numpy(target_audio))[0].numpy()
|
81
|
+
mixture_freq = ft.execute_all(torch.from_numpy(mixture_audio))[0].numpy()
|
61
82
|
|
62
83
|
frames, bins = target_freq.shape
|
63
84
|
truth = np.empty((frames, bins * 4), dtype=np.float32)
|
64
|
-
truth[:, : bins * 2] = _stack_real_imag(target_freq,
|
65
|
-
truth[:, bins * 2 :] = _stack_real_imag(mixture_freq,
|
85
|
+
truth[:, : bins * 2] = _stack_real_imag(target_freq, ft.ttype)
|
86
|
+
truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, ft.ttype)
|
66
87
|
return truth
|
67
88
|
|
68
89
|
|
@@ -70,11 +91,14 @@ def target_swin_f_validate(_config: dict) -> None:
|
|
70
91
|
pass
|
71
92
|
|
72
93
|
|
73
|
-
def target_swin_f_parameters(
|
74
|
-
|
94
|
+
def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
95
|
+
from pyaaware import ForwardTransform
|
96
|
+
from pyaaware import feature_forward_transform_config
|
97
|
+
|
98
|
+
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
75
99
|
|
76
100
|
|
77
|
-
def target_swin_f(
|
101
|
+
def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
|
78
102
|
"""Frequency domain target with synthesis window truth function
|
79
103
|
|
80
104
|
Calculates the true transform of the target using the STFT
|
@@ -85,20 +109,29 @@ def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth
|
|
85
109
|
Output shape: [:, 2 * bins] (stacked real, imag)
|
86
110
|
"""
|
87
111
|
import numpy as np
|
112
|
+
import torch
|
113
|
+
from pyaaware import ForwardTransform
|
114
|
+
from pyaaware import InverseTransform
|
115
|
+
from pyaaware import feature_forward_transform_config
|
116
|
+
from pyaaware import feature_inverse_transform_config
|
88
117
|
|
89
118
|
from sonusai.utils import stack_complex
|
90
119
|
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
120
|
+
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
121
|
+
it = InverseTransform(**feature_inverse_transform_config(mixdb.feature))
|
122
|
+
|
123
|
+
target_audio = mixdb.mixture_targets(m_id)[target_index]
|
124
|
+
|
125
|
+
truth = np.empty((len(target_audio) // ft.overlap, ft.bins * 2), dtype=np.float32)
|
126
|
+
for idx, offset in enumerate(range(0, len(target_audio), ft.overlap)):
|
127
|
+
audio_frame = torch.from_numpy(np.multiply(target_audio[offset : offset + ft.overlap], it.window))
|
128
|
+
target_freq = ft.execute(audio_frame)[0].numpy()
|
96
129
|
truth[idx] = stack_complex(target_freq)
|
97
130
|
|
98
131
|
return truth
|
99
132
|
|
100
133
|
|
101
|
-
def _stack_real_imag(data:
|
134
|
+
def _stack_real_imag(data: Truth, ttype: str) -> Truth:
|
102
135
|
import numpy as np
|
103
136
|
|
104
137
|
from sonusai.utils import stack_complex
|