sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
@@ -1,65 +1,91 @@
|
|
1
1
|
from sonusai.mixture.datatypes import Truth
|
2
|
-
from sonusai.mixture.truth_functions.
|
2
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
3
4
|
|
4
5
|
|
5
|
-
def
|
6
|
+
def _strictly_decreasing(list_to_check: list) -> bool:
|
7
|
+
from itertools import pairwise
|
8
|
+
|
9
|
+
return all(x > y for x, y in pairwise(list_to_check))
|
10
|
+
|
11
|
+
|
12
|
+
def sed_validate(config: dict) -> None:
|
13
|
+
if len(config) == 0:
|
14
|
+
raise AttributeError("sed truth function is missing config")
|
15
|
+
|
16
|
+
parameters = ["thresholds"]
|
17
|
+
for parameter in parameters:
|
18
|
+
if parameter not in config:
|
19
|
+
raise AttributeError(f"sed truth function is missing required '{parameter}'")
|
20
|
+
|
21
|
+
thresholds = config["thresholds"]
|
22
|
+
if not _strictly_decreasing(thresholds):
|
23
|
+
raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
|
24
|
+
|
25
|
+
|
26
|
+
def sed_parameters(config: TruthFunctionConfig) -> int:
|
27
|
+
return config.num_classes
|
28
|
+
|
29
|
+
|
30
|
+
def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
6
31
|
"""Sound energy detection truth generation function
|
7
32
|
|
8
|
-
Calculates sound energy detection truth using simple 3 threshold
|
9
|
-
hysteresis algorithm. SED outputs 3 possible probabilities of
|
10
|
-
sound presence: 1.0 present, 0.5 (transition/uncertain), 0 not
|
11
|
-
present. The output values will be assigned to the truth output
|
12
|
-
at the index specified in the
|
33
|
+
Calculates sound energy detection truth using simple 3 threshold
|
34
|
+
hysteresis algorithm. SED outputs 3 possible probabilities of
|
35
|
+
sound presence: 1.0 present, 0.5 (transition/uncertain), 0 not
|
36
|
+
present. The output values will be assigned to the truth output
|
37
|
+
at the index specified in the config.
|
38
|
+
|
39
|
+
Output shape: [:, num_classes]
|
40
|
+
|
41
|
+
index Truth index <int> or list(<int>)
|
13
42
|
|
14
|
-
|
43
|
+
index indicates which truth fields should be set.
|
44
|
+
0 indicates none, 1 is first element in truth output vector, 2 2nd element, etc.
|
15
45
|
|
16
|
-
|
17
|
-
|
46
|
+
Examples:
|
47
|
+
index = 5 truth in class 5, truth(4, 1)
|
48
|
+
index = [1, 5] truth in classes 1 and 5, truth([0, 4], 1)
|
18
49
|
|
19
|
-
|
20
|
-
|
21
|
-
|
50
|
+
In mutually-exclusive mode, a frame is expected to only
|
51
|
+
belong to one class and thus all probabilities must sum to
|
52
|
+
1. This is effectively truth for a classifier with multichannel
|
53
|
+
softmax output.
|
54
|
+
|
55
|
+
For multi-label classification each class is an individual
|
56
|
+
probability for that class and any given frame can be
|
57
|
+
assigned to multiple classes/labels, i.e., the classes are
|
58
|
+
not mutually-exclusive. For example, a NN classifier with
|
59
|
+
multichannel sigmoid output. In this case, index could
|
60
|
+
also be a vector with multiple class indices.
|
22
61
|
"""
|
23
62
|
import numpy as np
|
24
63
|
import torch
|
25
64
|
from pyaaware import SED
|
26
65
|
|
27
|
-
|
28
|
-
|
29
|
-
if data.config.config is None:
|
30
|
-
raise SonusAIError('Truth function SED missing config')
|
66
|
+
if len(data.target_audio) % config.frame_size != 0:
|
67
|
+
raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")
|
31
68
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
thresholds = data.config.config['thresholds']
|
38
|
-
if not _strictly_decreasing(thresholds):
|
39
|
-
raise SonusAIError(f'Truth function SED thresholds are not strictly decreasing: {thresholds}')
|
40
|
-
|
41
|
-
if len(data.target_audio) % data.frame_size != 0:
|
42
|
-
raise SonusAIError(f'Number of samples in audio is not a multiple of {data.frame_size}')
|
69
|
+
frames = config.target_fft.frames(data.target_audio)
|
70
|
+
parameters = sed_parameters(config)
|
71
|
+
if config.target_gain == 0:
|
72
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
43
73
|
|
44
74
|
# SED wants 1-based indices
|
45
|
-
s = SED(
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
75
|
+
s = SED(
|
76
|
+
thresholds=config.config["thresholds"],
|
77
|
+
index=config.class_indices,
|
78
|
+
frame_size=config.frame_size,
|
79
|
+
num_classes=config.num_classes,
|
80
|
+
)
|
50
81
|
|
51
|
-
|
52
|
-
|
53
|
-
if len(energy_t) != len(data.offsets):
|
54
|
-
raise SonusAIError(f'Number of frames in energy_t, {len(energy_t)},'
|
55
|
-
f' is not number of frames in truth, {len(data.offsets)}')
|
82
|
+
# Back out target gain
|
83
|
+
target_audio = data.target_audio / config.target_gain
|
56
84
|
|
57
|
-
|
58
|
-
|
59
|
-
data.truth[offset:offset + data.frame_size] = np.reshape(new_truth, (1, len(new_truth)))
|
85
|
+
# Compute energy
|
86
|
+
target_energy = config.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
|
60
87
|
|
61
|
-
|
88
|
+
if frames != target_energy.shape[0]:
|
89
|
+
raise ValueError("Incorrect frames calculation in sed truth function")
|
62
90
|
|
63
|
-
|
64
|
-
def _strictly_decreasing(list_to_check: list) -> bool:
|
65
|
-
return all(x > y for x, y in zip(list_to_check, list_to_check[1:]))
|
91
|
+
return s.execute_all(target_energy)
|
@@ -1,146 +1,109 @@
|
|
1
|
-
from sonusai import ForwardTransform
|
2
|
-
|
3
1
|
from sonusai.mixture.datatypes import AudioF
|
4
|
-
from sonusai.mixture.datatypes import AudioT
|
5
2
|
from sonusai.mixture.datatypes import Truth
|
6
|
-
from sonusai.mixture.truth_functions.
|
7
|
-
|
8
|
-
|
9
|
-
def target_f(data: Data) -> Truth:
|
10
|
-
"""Frequency domain target truth function
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
4
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
11
5
|
|
12
|
-
Calculates the true transform of the target using the STFT
|
13
|
-
configuration defined by the feature. This will include a
|
14
|
-
forward transform window if defined by the feature.
|
15
6
|
|
16
|
-
|
17
|
-
|
18
|
-
"""
|
19
|
-
from sonusai import SonusAIError
|
7
|
+
def target_f_validate(_config: dict) -> None:
|
8
|
+
pass
|
20
9
|
|
21
|
-
if data.config.num_classes != data.feature_parameters:
|
22
|
-
raise SonusAIError(f'Invalid num_classes for target_f truth: {data.config.num_classes}')
|
23
10
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
offset=offset,
|
28
|
-
frame_size=data.frame_size,
|
29
|
-
zero_based_indices=data.zero_based_indices,
|
30
|
-
bins=data.target_fft.bins,
|
31
|
-
ttype=data.ttype,
|
32
|
-
start=0,
|
33
|
-
truth=data.truth)
|
11
|
+
def target_f_parameters(config: TruthFunctionConfig) -> int:
|
12
|
+
if config.ttype == "tdac-co":
|
13
|
+
return config.target_fft.bins
|
34
14
|
|
35
|
-
return
|
15
|
+
return config.target_fft.bins * 2
|
36
16
|
|
37
17
|
|
38
|
-
|
39
|
-
|
40
|
-
"""Frequency domain target and mixture truth function
|
18
|
+
def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
19
|
+
"""Frequency domain target truth function
|
41
20
|
|
42
|
-
Calculates the true transform of the target
|
43
|
-
|
44
|
-
|
45
|
-
feature.
|
21
|
+
Calculates the true transform of the target using the STFT
|
22
|
+
configuration defined by the feature. This will include a
|
23
|
+
forward transform window if defined by the feature.
|
46
24
|
|
47
|
-
Output shape: [:, 2 *
|
48
|
-
|
49
|
-
(mixture stacked real, imag; or real only for tdac-co)
|
25
|
+
Output shape: [:, 2 * bins] (target stacked real, imag) or
|
26
|
+
[:, bins] (target real only for tdac-co)
|
50
27
|
"""
|
51
|
-
|
28
|
+
import torch
|
52
29
|
|
53
|
-
|
54
|
-
|
30
|
+
target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
|
31
|
+
return _stack_real_imag(target_freq, config.ttype)
|
55
32
|
|
56
|
-
target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
|
57
|
-
mixture_freq = _execute_fft(data.mixture_audio, data.mixture_fft, len(data.offsets))
|
58
33
|
|
59
|
-
|
60
|
-
|
61
|
-
offset=offset,
|
62
|
-
frame_size=data.frame_size,
|
63
|
-
zero_based_indices=data.zero_based_indices,
|
64
|
-
bins=data.target_fft.bins,
|
65
|
-
ttype=data.ttype,
|
66
|
-
start=0,
|
67
|
-
truth=data.truth)
|
34
|
+
def target_mixture_f_validate(_config: dict) -> None:
|
35
|
+
pass
|
68
36
|
|
69
|
-
data.truth = _stack_real_imag(data=mixture_freq[idx],
|
70
|
-
offset=offset,
|
71
|
-
frame_size=data.frame_size,
|
72
|
-
zero_based_indices=data.zero_based_indices,
|
73
|
-
bins=data.target_fft.bins,
|
74
|
-
ttype=data.ttype,
|
75
|
-
start=data.target_fft.bins * 2,
|
76
|
-
truth=data.truth)
|
77
37
|
|
78
|
-
|
38
|
+
def target_mixture_f_parameters(config: TruthFunctionConfig) -> int:
|
39
|
+
if config.ttype == "tdac-co":
|
40
|
+
return config.target_fft.bins * 2
|
79
41
|
|
42
|
+
return config.target_fft.bins * 4
|
80
43
|
|
81
|
-
def target_swin_f(data: Data) -> Truth:
|
82
|
-
"""Frequency domain target with synthesis window truth function
|
83
44
|
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
the
|
45
|
+
def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
46
|
+
"""Frequency domain target and mixture truth function
|
47
|
+
|
48
|
+
Calculates the true transform of the target and the mixture
|
49
|
+
using the STFT configuration defined by the feature. This
|
50
|
+
will include a forward transform window if defined by the
|
51
|
+
feature.
|
88
52
|
|
89
|
-
Output shape: [:,
|
53
|
+
Output shape: [:, 4 * bins] (target stacked real, imag; mixture stacked real, imag) or
|
54
|
+
[:, 2 * bins] (target real; mixture real for tdac-co)
|
90
55
|
"""
|
91
56
|
import numpy as np
|
57
|
+
import torch
|
92
58
|
|
93
|
-
|
59
|
+
target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
|
60
|
+
mixture_freq = config.mixture_fft.execute_all(torch.from_numpy(data.mixture_audio))[0].numpy()
|
94
61
|
|
95
|
-
|
96
|
-
|
62
|
+
frames, bins = target_freq.shape
|
63
|
+
truth = np.empty((frames, bins * 4), dtype=np.float32)
|
64
|
+
truth[:, : bins * 2] = _stack_real_imag(target_freq, config.ttype)
|
65
|
+
truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, config.ttype)
|
66
|
+
return truth
|
97
67
|
|
98
|
-
for idx, offset in enumerate(data.offsets):
|
99
|
-
target_freq, _ = data.target_fft.execute(
|
100
|
-
np.multiply(data.target_audio[offset:offset + data.frame_size], data.swin))
|
101
68
|
|
102
|
-
|
103
|
-
|
104
|
-
bins = _get_bin_slice(index, data.target_fft.bins)
|
105
|
-
data.truth[indices, bins] = np.real(target_freq[idx])
|
69
|
+
def target_swin_f_validate(_config: dict) -> None:
|
70
|
+
pass
|
106
71
|
|
107
|
-
bins = _get_bin_slice(bins.stop, data.target_fft.bins)
|
108
|
-
data.truth[indices, bins] = np.imag(target_freq[idx])
|
109
72
|
|
110
|
-
|
73
|
+
def target_swin_f_parameters(config: TruthFunctionConfig) -> int:
|
74
|
+
return config.target_fft.bins * 2
|
111
75
|
|
112
76
|
|
113
|
-
def
|
114
|
-
|
115
|
-
|
77
|
+
def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
78
|
+
"""Frequency domain target with synthesis window truth function
|
79
|
+
|
80
|
+
Calculates the true transform of the target using the STFT
|
81
|
+
configuration defined by the feature. This will include a
|
82
|
+
forward transform window if defined by the feature and also
|
83
|
+
the inverse transform (or synthesis) window.
|
116
84
|
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
return freq
|
85
|
+
Output shape: [:, 2 * bins] (stacked real, imag)
|
86
|
+
"""
|
87
|
+
import numpy as np
|
121
88
|
|
89
|
+
from sonusai.utils import stack_complex
|
122
90
|
|
123
|
-
|
124
|
-
|
91
|
+
truth = np.empty((len(data.target_audio) // config.frame_size, config.target_fft.bins * 2), dtype=np.float32)
|
92
|
+
for idx, offset in enumerate(range(0, len(data.target_audio), config.frame_size)):
|
93
|
+
target_freq = config.target_fft.execute(
|
94
|
+
np.multiply(data.target_audio[offset : offset + config.frame_size], config.swin)
|
95
|
+
)[0]
|
96
|
+
truth[idx] = stack_complex(target_freq)
|
125
97
|
|
98
|
+
return truth
|
126
99
|
|
127
|
-
|
128
|
-
|
129
|
-
frame_size: int,
|
130
|
-
zero_based_indices: list[int],
|
131
|
-
bins: int,
|
132
|
-
ttype: str,
|
133
|
-
start: int,
|
134
|
-
truth: Truth) -> Truth:
|
100
|
+
|
101
|
+
def _stack_real_imag(data: AudioF, ttype: str) -> Truth:
|
135
102
|
import numpy as np
|
136
103
|
|
137
|
-
|
138
|
-
for index in zero_based_indices:
|
139
|
-
b = _get_bin_slice(index + start, bins)
|
140
|
-
truth[i, b] = np.real(data)
|
104
|
+
from sonusai.utils import stack_complex
|
141
105
|
|
142
|
-
|
143
|
-
|
144
|
-
truth[i, b] = np.imag(data)
|
106
|
+
if ttype == "tdac-co":
|
107
|
+
return np.real(data)
|
145
108
|
|
146
|
-
return
|
109
|
+
return stack_complex(data)
|
sonusai/mkwav.py
CHANGED
@@ -16,17 +16,18 @@ Inputs:
|
|
16
16
|
MIXID A glob of mixture ID(s) to generate.
|
17
17
|
|
18
18
|
Outputs the following to the mixture database directory:
|
19
|
-
<id>
|
20
|
-
|
21
|
-
|
22
|
-
|
19
|
+
<id>
|
20
|
+
mixture.wav: mixture
|
21
|
+
target.wav: target (optional)
|
22
|
+
noise.wav: noise (optional)
|
23
|
+
metadata.txt
|
23
24
|
mkwav.log
|
24
25
|
|
25
26
|
"""
|
27
|
+
|
26
28
|
import signal
|
27
29
|
from dataclasses import dataclass
|
28
30
|
|
29
|
-
from sonusai.mixture import AudioT
|
30
31
|
from sonusai.mixture import MixtureDatabase
|
31
32
|
|
32
33
|
|
@@ -35,7 +36,7 @@ def signal_handler(_sig, _frame):
|
|
35
36
|
|
36
37
|
from sonusai import logger
|
37
38
|
|
38
|
-
logger.info(
|
39
|
+
logger.info("Canceled due to keyboard interrupt")
|
39
40
|
sys.exit(1)
|
40
41
|
|
41
42
|
|
@@ -44,70 +45,37 @@ signal.signal(signal.SIGINT, signal_handler)
|
|
44
45
|
|
45
46
|
@dataclass
|
46
47
|
class MPGlobal:
|
47
|
-
mixdb: MixtureDatabase
|
48
|
-
write_target: bool
|
49
|
-
write_noise: bool
|
50
|
-
|
51
|
-
|
52
|
-
MP_GLOBAL = MPGlobal()
|
53
|
-
|
48
|
+
mixdb: MixtureDatabase
|
49
|
+
write_target: bool
|
50
|
+
write_noise: bool
|
54
51
|
|
55
|
-
def mkwav(location: str, mixid: int) -> tuple[AudioT, AudioT, AudioT]:
|
56
|
-
import numpy as np
|
57
52
|
|
58
|
-
|
53
|
+
MP_GLOBAL: MPGlobal
|
59
54
|
|
60
|
-
data = genmix(location=location, mixids=mixid, force=False)
|
61
55
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
def _process_mixture(mixid: int) -> None:
|
66
|
-
from os.path import exists
|
56
|
+
def _process_mixture(m_id: int) -> None:
|
67
57
|
from os.path import join
|
68
|
-
from os.path import splitext
|
69
|
-
|
70
|
-
import h5py
|
71
|
-
import numpy as np
|
72
58
|
|
73
|
-
from sonusai.mixture import
|
59
|
+
from sonusai.mixture import write_mixture_metadata
|
74
60
|
from sonusai.utils import float_to_int16
|
75
61
|
from sonusai.utils import write_audio
|
76
62
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
need_data = True
|
84
|
-
if exists(mixture_filename + '.h5'):
|
85
|
-
with h5py.File(mixture_filename, 'r') as f:
|
86
|
-
if 'mixture' in f:
|
87
|
-
need_data = False
|
88
|
-
if MP_GLOBAL.write_target and 'targets' not in f:
|
89
|
-
need_data = True
|
90
|
-
if MP_GLOBAL.write_noise and 'noise' not in f:
|
91
|
-
need_data = True
|
92
|
-
|
93
|
-
if need_data:
|
94
|
-
mixture, target, noise = mkwav(location=MP_GLOBAL.mixdb.location, mixid=mixid)
|
95
|
-
else:
|
96
|
-
with h5py.File(mixture_filename, 'r') as f:
|
97
|
-
mixture = np.array(f['mixture'])
|
98
|
-
if MP_GLOBAL.write_target:
|
99
|
-
target = np.sum(np.array(f['targets']), axis=0)
|
100
|
-
if MP_GLOBAL.write_noise:
|
101
|
-
noise = np.array(f['noise'])
|
102
|
-
|
103
|
-
write_audio(name=mixture_basename + '_mixture.wav', audio=float_to_int16(mixture))
|
104
|
-
if MP_GLOBAL.write_target:
|
105
|
-
write_audio(name=mixture_basename + '_target.wav', audio=float_to_int16(target))
|
106
|
-
if MP_GLOBAL.write_noise:
|
107
|
-
write_audio(name=mixture_basename + '_noise.wav', audio=float_to_int16(noise))
|
63
|
+
global MP_GLOBAL
|
64
|
+
|
65
|
+
mixdb = MP_GLOBAL.mixdb
|
66
|
+
write_target = MP_GLOBAL.write_target
|
67
|
+
write_noise = MP_GLOBAL.write_noise
|
108
68
|
|
109
|
-
|
110
|
-
|
69
|
+
mixture = mixdb.mixture(m_id)
|
70
|
+
location = join(mixdb.location, mixture.name)
|
71
|
+
|
72
|
+
write_audio(name=join(location, "mixture.wav"), audio=float_to_int16(mixdb.mixture_mixture(m_id)))
|
73
|
+
if write_target:
|
74
|
+
write_audio(name=join(location, "target.wav"), audio=float_to_int16(mixdb.mixture_target(m_id)))
|
75
|
+
if write_noise:
|
76
|
+
write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
|
77
|
+
|
78
|
+
write_mixture_metadata(mixdb, mixture)
|
111
79
|
|
112
80
|
|
113
81
|
def main() -> None:
|
@@ -118,63 +86,62 @@ def main() -> None:
|
|
118
86
|
|
119
87
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
120
88
|
|
121
|
-
verbose = args[
|
122
|
-
mixid = args[
|
123
|
-
MP_GLOBAL.write_target = args[
|
124
|
-
MP_GLOBAL.write_noise = args[
|
125
|
-
location = args[
|
89
|
+
verbose = args["--verbose"]
|
90
|
+
mixid = args["--mixid"]
|
91
|
+
MP_GLOBAL.write_target = args["--target"]
|
92
|
+
MP_GLOBAL.write_noise = args["--noise"]
|
93
|
+
location = args["LOC"]
|
126
94
|
|
127
95
|
import time
|
128
96
|
from os.path import join
|
129
97
|
|
130
|
-
from tqdm import tqdm
|
131
|
-
|
132
98
|
import sonusai
|
133
99
|
from sonusai import create_file_handler
|
134
100
|
from sonusai import initial_log_messages
|
135
101
|
from sonusai import logger
|
136
102
|
from sonusai import update_console_handler
|
137
103
|
from sonusai.mixture import check_audio_files_exist
|
138
|
-
from sonusai.utils import pp_tqdm_imap
|
139
104
|
from sonusai.utils import human_readable_size
|
105
|
+
from sonusai.utils import par_track
|
140
106
|
from sonusai.utils import seconds_to_hms
|
107
|
+
from sonusai.utils import track
|
141
108
|
|
142
109
|
start_time = time.monotonic()
|
143
110
|
|
144
|
-
create_file_handler(join(location,
|
111
|
+
create_file_handler(join(location, "mkwav.log"))
|
145
112
|
update_console_handler(verbose)
|
146
|
-
initial_log_messages(
|
113
|
+
initial_log_messages("mkwav")
|
147
114
|
|
148
|
-
logger.info(f
|
115
|
+
logger.info(f"Load mixture database from {location}")
|
149
116
|
MP_GLOBAL.mixdb = MixtureDatabase(location)
|
150
117
|
mixid = MP_GLOBAL.mixdb.mixids_to_list(mixid)
|
151
118
|
|
152
119
|
total_samples = MP_GLOBAL.mixdb.total_samples(mixid)
|
153
120
|
duration = total_samples / sonusai.mixture.SAMPLE_RATE
|
154
121
|
|
155
|
-
logger.info(
|
156
|
-
logger.info(f
|
157
|
-
logger.info(f
|
122
|
+
logger.info("")
|
123
|
+
logger.info(f"Found {len(mixid):,} mixtures to process")
|
124
|
+
logger.info(f"{total_samples:,} samples")
|
158
125
|
|
159
126
|
check_audio_files_exist(MP_GLOBAL.mixdb)
|
160
127
|
|
161
|
-
progress =
|
162
|
-
|
128
|
+
progress = track(total=len(mixid))
|
129
|
+
par_track(_process_mixture, mixid, progress=progress)
|
163
130
|
progress.close()
|
164
131
|
|
165
|
-
logger.info(f
|
166
|
-
logger.info(
|
167
|
-
logger.info(f
|
168
|
-
logger.info(f
|
132
|
+
logger.info(f"Wrote {len(mixid)} mixtures to {location}")
|
133
|
+
logger.info("")
|
134
|
+
logger.info(f"Duration: {seconds_to_hms(seconds=duration)}")
|
135
|
+
logger.info(f"mixture: {human_readable_size(total_samples * 2, 1)}")
|
169
136
|
if MP_GLOBAL.write_target:
|
170
|
-
logger.info(f
|
137
|
+
logger.info(f"target: {human_readable_size(total_samples * 2, 1)}")
|
171
138
|
if MP_GLOBAL.write_noise:
|
172
|
-
logger.info(f
|
139
|
+
logger.info(f"noise: {human_readable_size(total_samples * 2, 1)}")
|
173
140
|
|
174
141
|
end_time = time.monotonic()
|
175
|
-
logger.info(f
|
176
|
-
logger.info(
|
142
|
+
logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
|
143
|
+
logger.info("")
|
177
144
|
|
178
145
|
|
179
|
-
if __name__ ==
|
146
|
+
if __name__ == "__main__":
|
180
147
|
main()
|