sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +16 -3
- sonusai/audiofe.py +241 -77
- sonusai/calc_metric_spenh.py +71 -73
- sonusai/config/__init__.py +3 -0
- sonusai/config/config.py +61 -0
- sonusai/config/config.yml +20 -0
- sonusai/config/constants.py +8 -0
- sonusai/constants.py +11 -0
- sonusai/data/genmixdb.yml +21 -36
- sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
- sonusai/deprecated/plot.py +4 -5
- sonusai/doc/doc.py +4 -4
- sonusai/doc.py +11 -4
- sonusai/genft.py +43 -45
- sonusai/genmetrics.py +25 -19
- sonusai/genmix.py +54 -82
- sonusai/genmixdb.py +88 -264
- sonusai/ir_metric.py +30 -34
- sonusai/lsdb.py +41 -48
- sonusai/main.py +15 -22
- sonusai/metrics/calc_audio_stats.py +4 -293
- sonusai/metrics/calc_class_weights.py +4 -4
- sonusai/metrics/calc_optimal_thresholds.py +8 -5
- sonusai/metrics/calc_pesq.py +2 -2
- sonusai/metrics/calc_segsnr_f.py +4 -4
- sonusai/metrics/calc_speech.py +25 -13
- sonusai/metrics/class_summary.py +7 -7
- sonusai/metrics/confusion_matrix_summary.py +5 -5
- sonusai/metrics/one_hot.py +4 -4
- sonusai/metrics/snr_summary.py +7 -7
- sonusai/metrics_summary.py +38 -45
- sonusai/mixture/__init__.py +4 -104
- sonusai/mixture/audio.py +10 -39
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/config.py +251 -271
- sonusai/mixture/constants.py +35 -39
- sonusai/mixture/data_io.py +25 -36
- sonusai/mixture/db_datatypes.py +58 -22
- sonusai/mixture/effects.py +386 -0
- sonusai/mixture/feature.py +7 -11
- sonusai/mixture/generation.py +478 -628
- sonusai/mixture/helpers.py +82 -184
- sonusai/mixture/ir_delay.py +3 -4
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +6 -12
- sonusai/mixture/mixdb.py +910 -729
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +2 -2
- sonusai/mixture/truth.py +17 -15
- sonusai/mixture/truth_functions/crm.py +12 -12
- sonusai/mixture/truth_functions/energy.py +22 -22
- sonusai/mixture/truth_functions/file.py +5 -5
- sonusai/mixture/truth_functions/metadata.py +4 -4
- sonusai/mixture/truth_functions/metrics.py +4 -4
- sonusai/mixture/truth_functions/phoneme.py +3 -3
- sonusai/mixture/truth_functions/sed.py +11 -13
- sonusai/mixture/truth_functions/target.py +10 -10
- sonusai/mkwav.py +26 -29
- sonusai/onnx_predict.py +240 -88
- sonusai/queries/__init__.py +2 -2
- sonusai/queries/queries.py +38 -34
- sonusai/speech/librispeech.py +1 -1
- sonusai/speech/mcgill.py +1 -1
- sonusai/speech/timit.py +2 -2
- sonusai/summarize_metric_spenh.py +10 -17
- sonusai/utils/__init__.py +7 -1
- sonusai/utils/asl_p56.py +2 -2
- sonusai/utils/asr.py +2 -2
- sonusai/utils/asr_functions/aaware_whisper.py +4 -5
- sonusai/utils/choice.py +31 -0
- sonusai/utils/compress.py +1 -1
- sonusai/utils/dataclass_from_dict.py +19 -1
- sonusai/utils/energy_f.py +3 -3
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/onnx_utils.py +3 -17
- sonusai/utils/print_mixture_details.py +21 -19
- sonusai/utils/{temp_seed.py → rand.py} +3 -3
- sonusai/utils/read_predict_data.py +2 -2
- sonusai/utils/reshape.py +3 -3
- sonusai/utils/stratified_shuffle_split.py +3 -3
- sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
- sonusai/utils/write_audio.py +2 -2
- sonusai/vars.py +11 -4
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
- sonusai-1.0.2.dist-info/RECORD +138 -0
- sonusai/mixture/augmentation.py +0 -444
- sonusai/mixture/class_count.py +0 -15
- sonusai/mixture/eq_rule_is_valid.py +0 -45
- sonusai/mixture/target_class_balancing.py +0 -107
- sonusai/mixture/targets.py +0 -175
- sonusai-0.20.3.dist-info/RECORD +0 -128
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
- {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/truth.py
CHANGED
@@ -1,33 +1,35 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ..datatypes import Truth
|
2
|
+
from ..datatypes import TruthsDict
|
3
|
+
from .mixdb import MixtureDatabase
|
3
4
|
|
4
5
|
|
5
|
-
def truth_function(mixdb: MixtureDatabase, m_id: int) ->
|
6
|
-
from
|
7
|
-
from
|
6
|
+
def truth_function(mixdb: MixtureDatabase, m_id: int) -> TruthsDict:
|
7
|
+
from ..datatypes import TruthDict
|
8
|
+
from . import truth_functions
|
8
9
|
|
9
|
-
result:
|
10
|
-
for
|
10
|
+
result: TruthsDict = {}
|
11
|
+
for category, source in mixdb.mixture(m_id).all_sources.items():
|
11
12
|
truth: TruthDict = {}
|
12
|
-
|
13
|
-
for name, config in
|
13
|
+
source_file = mixdb.source_file(source.file_id)
|
14
|
+
for name, config in source_file.truth_configs.items():
|
14
15
|
try:
|
15
|
-
truth[name] = getattr(truth_functions, config.function)(mixdb, m_id,
|
16
|
+
truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, category, config.config)
|
16
17
|
except AttributeError as e:
|
17
18
|
raise AttributeError(f"Unsupported truth function: {config.function}") from e
|
18
19
|
except Exception as e:
|
19
20
|
raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
|
20
21
|
|
21
|
-
|
22
|
+
if truth:
|
23
|
+
result[category] = truth
|
22
24
|
|
23
25
|
return result
|
24
26
|
|
25
27
|
|
26
|
-
def
|
27
|
-
"""Get a list of
|
28
|
+
def get_class_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
|
29
|
+
"""Get a list of class indices for a given mixid."""
|
28
30
|
indices: list[int] = []
|
29
|
-
for
|
30
|
-
indices.append(*mixdb.
|
31
|
+
for source_id in [source.file_id for source in mixdb.mixture(mixid).all_sources.values()]:
|
32
|
+
indices.append(*mixdb.source_file(source_id).class_indices)
|
31
33
|
|
32
34
|
return sorted(set(indices))
|
33
35
|
|
@@ -1,31 +1,31 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
|
-
def _core(mixdb: MixtureDatabase, m_id: int,
|
5
|
+
def _core(mixdb: MixtureDatabase, m_id: int, category: str, parameters: int, polar: bool) -> Truth:
|
6
6
|
import numpy as np
|
7
7
|
import torch
|
8
8
|
from pyaaware import ForwardTransform
|
9
9
|
from pyaaware import feature_forward_transform_config
|
10
10
|
from pyaaware import feature_inverse_transform_config
|
11
11
|
|
12
|
-
|
12
|
+
source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
|
13
13
|
t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
14
14
|
n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
15
15
|
|
16
|
-
frames = t_ft.frames(
|
17
|
-
if mixdb.mixture(m_id).
|
16
|
+
frames = t_ft.frames(source_audio)
|
17
|
+
if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
|
18
18
|
return np.zeros((frames, parameters), dtype=np.float32)
|
19
19
|
|
20
20
|
noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))
|
21
21
|
|
22
22
|
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
23
23
|
|
24
|
-
frames = len(
|
24
|
+
frames = len(source_audio) // frame_size
|
25
25
|
truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
|
26
26
|
for frame in range(frames):
|
27
27
|
offset = frame * frame_size
|
28
|
-
target_f = t_ft.execute(
|
28
|
+
target_f = t_ft.execute(source_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
29
29
|
noise_f = n_ft.execute(noise_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
|
30
30
|
mixture_f = target_f + noise_f
|
31
31
|
|
@@ -58,7 +58,7 @@ def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
58
58
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
59
59
|
|
60
60
|
|
61
|
-
def crm(mixdb: MixtureDatabase, m_id: int,
|
61
|
+
def crm(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
62
62
|
"""Complex ratio mask truth generation function
|
63
63
|
|
64
64
|
Calculates the true complex ratio mask (CRM) truth which is a complex number
|
@@ -71,7 +71,7 @@ def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) ->
|
|
71
71
|
return _core(
|
72
72
|
mixdb=mixdb,
|
73
73
|
m_id=m_id,
|
74
|
-
|
74
|
+
category=category,
|
75
75
|
parameters=crm_parameters(mixdb.feature, mixdb.num_classes, _config),
|
76
76
|
polar=False,
|
77
77
|
)
|
@@ -88,7 +88,7 @@ def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
88
88
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
89
89
|
|
90
90
|
|
91
|
-
def crmp(mixdb: MixtureDatabase, m_id: int,
|
91
|
+
def crmp(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
92
92
|
"""Complex ratio mask polar truth generation function
|
93
93
|
|
94
94
|
Same as the crm function except the results are magnitude and phase
|
@@ -99,7 +99,7 @@ def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) ->
|
|
99
99
|
return _core(
|
100
100
|
mixdb=mixdb,
|
101
101
|
m_id=m_id,
|
102
|
-
|
102
|
+
category=category,
|
103
103
|
parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
|
104
104
|
polar=True,
|
105
105
|
)
|
@@ -1,14 +1,14 @@
|
|
1
1
|
import numpy as np
|
2
2
|
|
3
|
-
from
|
4
|
-
from
|
5
|
-
from
|
3
|
+
from ...datatypes import Truth
|
4
|
+
from ...utils.load_object import load_object
|
5
|
+
from ..mixdb import MixtureDatabase
|
6
6
|
|
7
7
|
|
8
8
|
def _core(
|
9
9
|
mixdb: MixtureDatabase,
|
10
10
|
m_id: int,
|
11
|
-
|
11
|
+
category: str,
|
12
12
|
config: dict,
|
13
13
|
parameters: int,
|
14
14
|
mapped: bool,
|
@@ -21,27 +21,27 @@ def _core(
|
|
21
21
|
from pyaaware import ForwardTransform
|
22
22
|
from pyaaware import feature_forward_transform_config
|
23
23
|
|
24
|
-
from
|
24
|
+
from ...utils.energy_f import compute_energy_f
|
25
25
|
|
26
|
-
|
26
|
+
source_audio = mixdb.mixture_sources(m_id)[category]
|
27
27
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
28
28
|
|
29
|
-
frames = ft.frames(torch.from_numpy(
|
29
|
+
frames = ft.frames(torch.from_numpy(source_audio))
|
30
30
|
|
31
|
-
if mixdb.mixture(m_id).
|
31
|
+
if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
|
32
32
|
return np.zeros((frames, parameters), dtype=np.float32)
|
33
33
|
|
34
34
|
noise_audio = mixdb.mixture_noise(m_id)
|
35
35
|
|
36
|
-
|
36
|
+
source_energy = compute_energy_f(time_domain=source_audio, transform=ft)
|
37
37
|
noise_energy = None
|
38
38
|
if snr:
|
39
39
|
noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)
|
40
40
|
|
41
|
-
frames = len(
|
41
|
+
frames = len(source_energy)
|
42
42
|
truth = np.empty((frames, ft.bins), dtype=np.float32)
|
43
43
|
for frame in range(frames):
|
44
|
-
tmp =
|
44
|
+
tmp = source_energy[frame]
|
45
45
|
|
46
46
|
if noise_energy is not None:
|
47
47
|
old_err = np.seterr(divide="ignore", invalid="ignore")
|
@@ -86,7 +86,7 @@ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
86
86
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
87
87
|
|
88
88
|
|
89
|
-
def energy_f(mixdb: MixtureDatabase, m_id: int,
|
89
|
+
def energy_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
|
90
90
|
"""Frequency domain energy truth generation function
|
91
91
|
|
92
92
|
Calculates the true energy per bin:
|
@@ -100,7 +100,7 @@ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict,
|
|
100
100
|
return _core(
|
101
101
|
mixdb=mixdb,
|
102
102
|
m_id=m_id,
|
103
|
-
|
103
|
+
category=category,
|
104
104
|
config=config,
|
105
105
|
parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
106
106
|
mapped=False,
|
@@ -120,7 +120,7 @@ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
120
120
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
121
121
|
|
122
122
|
|
123
|
-
def snr_f(mixdb: MixtureDatabase, m_id: int,
|
123
|
+
def snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
|
124
124
|
"""Frequency domain SNR truth function documentation
|
125
125
|
|
126
126
|
Calculates the true SNR per bin:
|
@@ -134,7 +134,7 @@ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, us
|
|
134
134
|
return _core(
|
135
135
|
mixdb=mixdb,
|
136
136
|
m_id=m_id,
|
137
|
-
|
137
|
+
category=category,
|
138
138
|
config=config,
|
139
139
|
parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
140
140
|
mapped=False,
|
@@ -159,7 +159,7 @@ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> i
|
|
159
159
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins
|
160
160
|
|
161
161
|
|
162
|
-
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int,
|
162
|
+
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
|
163
163
|
"""Frequency domain mapped SNR truth function documentation
|
164
164
|
|
165
165
|
Output shape: [:, bins]
|
@@ -167,7 +167,7 @@ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: d
|
|
167
167
|
return _core(
|
168
168
|
mixdb=mixdb,
|
169
169
|
m_id=m_id,
|
170
|
-
|
170
|
+
category=category,
|
171
171
|
config=config,
|
172
172
|
parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
|
173
173
|
mapped=True,
|
@@ -184,7 +184,7 @@ def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
|
|
184
184
|
return 1
|
185
185
|
|
186
186
|
|
187
|
-
def energy_t(mixdb: MixtureDatabase, m_id: int,
|
187
|
+
def energy_t(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
188
188
|
"""Time domain energy truth function documentation
|
189
189
|
|
190
190
|
Calculates the true time domain energy of each frame:
|
@@ -210,13 +210,13 @@ def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict
|
|
210
210
|
from pyaaware import ForwardTransform
|
211
211
|
from pyaaware import feature_forward_transform_config
|
212
212
|
|
213
|
-
|
213
|
+
source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
|
214
214
|
|
215
215
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
216
216
|
|
217
|
-
frames = ft.frames(
|
217
|
+
frames = ft.frames(source_audio)
|
218
218
|
parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, _config)
|
219
|
-
if mixdb.mixture(m_id).
|
219
|
+
if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
|
220
220
|
return np.zeros((frames, parameters), dtype=np.float32)
|
221
221
|
|
222
|
-
return ft.execute_all(
|
222
|
+
return ft.execute_all(source_audio)[1].numpy()
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
5
|
def file_validate(config: dict) -> None:
|
@@ -26,13 +26,13 @@ def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
|
|
26
26
|
return truth.shape[-1]
|
27
27
|
|
28
28
|
|
29
|
-
def file(mixdb: MixtureDatabase, m_id: int,
|
29
|
+
def file(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
|
30
30
|
"""file truth function documentation"""
|
31
31
|
import h5py
|
32
32
|
import numpy as np
|
33
33
|
from pyaaware import feature_inverse_transform_config
|
34
34
|
|
35
|
-
|
35
|
+
source_audio = mixdb.mixture_sources(m_id)[category]
|
36
36
|
|
37
37
|
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
38
38
|
|
@@ -42,7 +42,7 @@ def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) ->
|
|
42
42
|
if truth.ndim != 2:
|
43
43
|
raise ValueError("Truth file data is not 2 dimensions")
|
44
44
|
|
45
|
-
if truth.shape[0] != len(
|
45
|
+
if truth.shape[0] != len(source_audio) // frame_size:
|
46
46
|
raise ValueError("Truth file does not contain the right amount of frames")
|
47
47
|
|
48
48
|
return truth
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
5
|
def metadata_validate(config: dict) -> None:
|
@@ -16,9 +16,9 @@ def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int
|
|
16
16
|
return None
|
17
17
|
|
18
18
|
|
19
|
-
def metadata(mixdb: MixtureDatabase, m_id: int,
|
19
|
+
def metadata(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
|
20
20
|
"""Metadata truth generation function
|
21
21
|
|
22
22
|
Retrieves metadata from target.
|
23
23
|
"""
|
24
|
-
return mixdb.mixture_speech_metadata(m_id, config["tier"])[
|
24
|
+
return mixdb.mixture_speech_metadata(m_id, config["tier"])[category]
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
5
|
def metrics_validate(config: dict) -> None:
|
@@ -16,7 +16,7 @@ def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int |
|
|
16
16
|
return None
|
17
17
|
|
18
18
|
|
19
|
-
def metrics(mixdb: MixtureDatabase, m_id: int,
|
19
|
+
def metrics(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
|
20
20
|
"""Metadata truth generation function
|
21
21
|
|
22
22
|
Retrieves metrics from target.
|
@@ -25,4 +25,4 @@ def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
|
|
25
25
|
m = [config["metric"]]
|
26
26
|
else:
|
27
27
|
m = config["metric"]
|
28
|
-
return mixdb.mixture_metrics(m_id, m)[m[0]][
|
28
|
+
return mixdb.mixture_metrics(m_id, m)[m[0]][category]
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
5
|
def phoneme_validate(_config: dict) -> None:
|
@@ -10,7 +10,7 @@ def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
|
|
10
10
|
raise NotImplementedError("Truth function phoneme is not supported yet")
|
11
11
|
|
12
12
|
|
13
|
-
def phoneme(_mixdb: MixtureDatabase, _m_id: int,
|
13
|
+
def phoneme(_mixdb: MixtureDatabase, _m_id: int, _category: str, _config: dict) -> Truth:
|
14
14
|
"""Read in .txt transcript and run a Python function to generate text grid data
|
15
15
|
(indicating which phonemes are active). Then generate truth based on this data and put
|
16
16
|
in the correct classes based on the index in the config.
|
@@ -1,5 +1,7 @@
|
|
1
|
-
from
|
2
|
-
|
1
|
+
from numpy.lib.utils import source
|
2
|
+
|
3
|
+
from ...datatypes import Truth
|
4
|
+
from ..mixdb import MixtureDatabase
|
3
5
|
|
4
6
|
|
5
7
|
def sed_validate(config: dict) -> None:
|
@@ -20,7 +22,7 @@ def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
|
|
20
22
|
return num_classes
|
21
23
|
|
22
24
|
|
23
|
-
def sed(mixdb: MixtureDatabase, m_id: int,
|
25
|
+
def sed(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
|
24
26
|
"""Sound energy detection truth generation function
|
25
27
|
|
26
28
|
Calculates sound energy detection truth using simple 3 threshold
|
@@ -59,34 +61,30 @@ def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> T
|
|
59
61
|
from pyaaware import feature_forward_transform_config
|
60
62
|
from pyaaware import feature_inverse_transform_config
|
61
63
|
|
62
|
-
|
64
|
+
source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
|
63
65
|
|
64
66
|
frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
|
65
67
|
|
66
68
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
67
69
|
|
68
|
-
if len(
|
70
|
+
if len(source_audio) % frame_size != 0:
|
69
71
|
raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")
|
70
72
|
|
71
|
-
frames = ft.frames(
|
73
|
+
frames = ft.frames(source_audio)
|
72
74
|
parameters = sed_parameters(mixdb.feature, mixdb.num_classes, config)
|
73
|
-
|
74
|
-
if target_gain == 0:
|
75
|
+
if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
|
75
76
|
return np.zeros((frames, parameters), dtype=np.float32)
|
76
77
|
|
77
78
|
# SED wants 1-based indices
|
78
79
|
s = SED(
|
79
80
|
thresholds=config["thresholds"],
|
80
|
-
index=mixdb.
|
81
|
+
index=mixdb.source_file(mixdb.mixture(m_id).all_sources[category].file_id).class_indices,
|
81
82
|
frame_size=frame_size,
|
82
83
|
num_classes=mixdb.num_classes,
|
83
84
|
)
|
84
85
|
|
85
|
-
# Back out target gain
|
86
|
-
target_audio = target_audio / target_gain
|
87
|
-
|
88
86
|
# Compute energy
|
89
|
-
target_energy = ft.execute_all(
|
87
|
+
target_energy = ft.execute_all(source_audio)[1].numpy()
|
90
88
|
|
91
89
|
if frames != target_energy.shape[0]:
|
92
90
|
raise ValueError("Incorrect frames calculation in sed truth function")
|
@@ -1,5 +1,5 @@
|
|
1
|
-
from
|
2
|
-
from
|
1
|
+
from ...datatypes import Truth
|
2
|
+
from ..mixdb import MixtureDatabase
|
3
3
|
|
4
4
|
|
5
5
|
def target_f_validate(_config: dict) -> None:
|
@@ -18,7 +18,7 @@ def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
|
|
18
18
|
return ft.bins * 2
|
19
19
|
|
20
20
|
|
21
|
-
def target_f(mixdb: MixtureDatabase, m_id: int,
|
21
|
+
def target_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
22
22
|
"""Frequency domain target truth function
|
23
23
|
|
24
24
|
Calculates the true transform of the target using the STFT
|
@@ -34,7 +34,7 @@ def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict
|
|
34
34
|
|
35
35
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
36
36
|
|
37
|
-
target_audio = torch.from_numpy(mixdb.
|
37
|
+
target_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
|
38
38
|
|
39
39
|
target_freq = ft.execute_all(target_audio)[0].numpy()
|
40
40
|
return _stack_real_imag(target_freq, ft.ttype)
|
@@ -56,7 +56,7 @@ def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict)
|
|
56
56
|
return ft.bins * 4
|
57
57
|
|
58
58
|
|
59
|
-
def target_mixture_f(mixdb: MixtureDatabase, m_id: int,
|
59
|
+
def target_mixture_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
60
60
|
"""Frequency domain target and mixture truth function
|
61
61
|
|
62
62
|
Calculates the true transform of the target and the mixture
|
@@ -74,7 +74,7 @@ def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _conf
|
|
74
74
|
|
75
75
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
76
76
|
|
77
|
-
target_audio = torch.from_numpy(mixdb.
|
77
|
+
target_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
|
78
78
|
mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))
|
79
79
|
|
80
80
|
target_freq = ft.execute_all(torch.from_numpy(target_audio))[0].numpy()
|
@@ -98,7 +98,7 @@ def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) ->
|
|
98
98
|
return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
|
99
99
|
|
100
100
|
|
101
|
-
def target_swin_f(mixdb: MixtureDatabase, m_id: int,
|
101
|
+
def target_swin_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
|
102
102
|
"""Frequency domain target with synthesis window truth function
|
103
103
|
|
104
104
|
Calculates the true transform of the target using the STFT
|
@@ -115,12 +115,12 @@ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config:
|
|
115
115
|
from pyaaware import feature_forward_transform_config
|
116
116
|
from pyaaware import feature_inverse_transform_config
|
117
117
|
|
118
|
-
from
|
118
|
+
from ...utils.stacked_complex import stack_complex
|
119
119
|
|
120
120
|
ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
|
121
121
|
it = InverseTransform(**feature_inverse_transform_config(mixdb.feature))
|
122
122
|
|
123
|
-
target_audio = mixdb.
|
123
|
+
target_audio = mixdb.mixture_sources(m_id)[category]
|
124
124
|
|
125
125
|
truth = np.empty((len(target_audio) // ft.overlap, ft.bins * 2), dtype=np.float32)
|
126
126
|
for idx, offset in enumerate(range(0, len(target_audio), ft.overlap)):
|
@@ -134,7 +134,7 @@ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config:
|
|
134
134
|
def _stack_real_imag(data: Truth, ttype: str) -> Truth:
|
135
135
|
import numpy as np
|
136
136
|
|
137
|
-
from
|
137
|
+
from ...utils.stacked_complex import stack_complex
|
138
138
|
|
139
139
|
if ttype == "tdac-co":
|
140
140
|
return np.real(data)
|
sonusai/mkwav.py
CHANGED
@@ -6,8 +6,8 @@ options:
|
|
6
6
|
-h, --help
|
7
7
|
-v, --verbose Be verbose.
|
8
8
|
-i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
|
9
|
-
-t, --
|
10
|
-
-s, --
|
9
|
+
-t, --source Write source file.
|
10
|
+
-s, --sources Write sources files.
|
11
11
|
-n, --noise Write noise file.
|
12
12
|
|
13
13
|
The mkwav command creates WAV files from a SonusAI database.
|
@@ -19,30 +19,17 @@ Inputs:
|
|
19
19
|
Outputs the following to the mixture database directory:
|
20
20
|
<id>
|
21
21
|
mixture.wav: mixture
|
22
|
-
|
23
|
-
|
22
|
+
source.wav: source (optional)
|
23
|
+
source_<c>.wav: source <category> (optional)
|
24
24
|
noise.wav: noise (optional)
|
25
25
|
metadata.txt
|
26
26
|
mkwav.log
|
27
27
|
|
28
28
|
"""
|
29
29
|
|
30
|
-
import signal
|
31
|
-
|
32
|
-
|
33
|
-
def signal_handler(_sig, _frame):
|
34
|
-
import sys
|
35
|
-
|
36
|
-
from sonusai import logger
|
37
|
-
|
38
|
-
logger.info("Canceled due to keyboard interrupt")
|
39
|
-
sys.exit(1)
|
40
|
-
|
41
|
-
|
42
|
-
signal.signal(signal.SIGINT, signal_handler)
|
43
|
-
|
44
30
|
|
45
31
|
def _process_mixture(m_id: int, location: str, write_target: bool, write_targets: bool, write_noise: bool) -> None:
|
32
|
+
from os import makedirs
|
46
33
|
from os.path import join
|
47
34
|
|
48
35
|
from sonusai.mixture import MixtureDatabase
|
@@ -52,14 +39,16 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
|
|
52
39
|
|
53
40
|
mixdb = MixtureDatabase(location)
|
54
41
|
|
55
|
-
|
42
|
+
index = mixdb.mixture(m_id).name
|
43
|
+
location = join(mixdb.location, "mixture", index)
|
44
|
+
makedirs(location, exist_ok=True)
|
56
45
|
|
57
46
|
write_audio(name=join(location, "mixture.wav"), audio=float_to_int16(mixdb.mixture_mixture(m_id)))
|
58
47
|
if write_target:
|
59
|
-
write_audio(name=join(location, "
|
48
|
+
write_audio(name=join(location, "source.wav"), audio=float_to_int16(mixdb.mixture_source(m_id)))
|
60
49
|
if write_targets:
|
61
|
-
for
|
62
|
-
write_audio(name=join(location, f"
|
50
|
+
for category, source in mixdb.mixture_sources(m_id).items():
|
51
|
+
write_audio(name=join(location, f"sources_{category}.wav"), audio=float_to_int16(source))
|
63
52
|
if write_noise:
|
64
53
|
write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
|
65
54
|
|
@@ -69,15 +58,15 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
|
|
69
58
|
def main() -> None:
|
70
59
|
from docopt import docopt
|
71
60
|
|
72
|
-
import
|
61
|
+
from sonusai import __version__ as sai_version
|
73
62
|
from sonusai.utils import trim_docstring
|
74
63
|
|
75
|
-
args = docopt(trim_docstring(__doc__), version=
|
64
|
+
args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
|
76
65
|
|
77
66
|
verbose = args["--verbose"]
|
78
67
|
mixid = args["--mixid"]
|
79
|
-
|
80
|
-
|
68
|
+
write_source = args["--source"]
|
69
|
+
write_sources = args["--sources"]
|
81
70
|
write_noise = args["--noise"]
|
82
71
|
location = args["LOC"]
|
83
72
|
|
@@ -118,12 +107,13 @@ def main() -> None:
|
|
118
107
|
partial(
|
119
108
|
_process_mixture,
|
120
109
|
location=location,
|
121
|
-
write_target=
|
122
|
-
write_targets=
|
110
|
+
write_target=write_source,
|
111
|
+
write_targets=write_sources,
|
123
112
|
write_noise=write_noise,
|
124
113
|
),
|
125
114
|
mixid,
|
126
115
|
progress=progress,
|
116
|
+
# no_par=True,
|
127
117
|
)
|
128
118
|
progress.close()
|
129
119
|
|
@@ -135,4 +125,11 @@ def main() -> None:
|
|
135
125
|
|
136
126
|
|
137
127
|
if __name__ == "__main__":
|
138
|
-
|
128
|
+
from sonusai import exception_handler
|
129
|
+
from sonusai.utils import register_keyboard_interrupt
|
130
|
+
|
131
|
+
register_keyboard_interrupt()
|
132
|
+
try:
|
133
|
+
main()
|
134
|
+
except Exception as e:
|
135
|
+
exception_handler(e)
|