sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -1,65 +1,91 @@
|
|
1
1
|
from sonusai.mixture.datatypes import Truth
|
2
|
-
from sonusai.mixture.truth_functions.
|
2
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
3
4
|
|
4
5
|
|
5
|
-
def
|
6
|
+
def _strictly_decreasing(list_to_check: list) -> bool:
|
7
|
+
from itertools import pairwise
|
8
|
+
|
9
|
+
return all(x > y for x, y in pairwise(list_to_check))
|
10
|
+
|
11
|
+
|
12
|
+
def sed_validate(config: dict) -> None:
|
13
|
+
if len(config) == 0:
|
14
|
+
raise AttributeError("sed truth function is missing config")
|
15
|
+
|
16
|
+
parameters = ["thresholds"]
|
17
|
+
for parameter in parameters:
|
18
|
+
if parameter not in config:
|
19
|
+
raise AttributeError(f"sed truth function is missing required '{parameter}'")
|
20
|
+
|
21
|
+
thresholds = config["thresholds"]
|
22
|
+
if not _strictly_decreasing(thresholds):
|
23
|
+
raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
|
24
|
+
|
25
|
+
|
26
|
+
def sed_parameters(config: TruthFunctionConfig) -> int:
|
27
|
+
return config.num_classes
|
28
|
+
|
29
|
+
|
30
|
+
def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
6
31
|
"""Sound energy detection truth generation function
|
7
32
|
|
8
|
-
Calculates sound energy detection truth using simple 3 threshold
|
9
|
-
hysteresis algorithm. SED outputs 3 possible probabilities of
|
10
|
-
sound presence: 1.0 present, 0.5 (transition/uncertain), 0 not
|
11
|
-
present. The output values will be assigned to the truth output
|
12
|
-
at the index specified in the
|
33
|
+
Calculates sound energy detection truth using simple 3 threshold
|
34
|
+
hysteresis algorithm. SED outputs 3 possible probabilities of
|
35
|
+
sound presence: 1.0 present, 0.5 (transition/uncertain), 0 not
|
36
|
+
present. The output values will be assigned to the truth output
|
37
|
+
at the index specified in the config.
|
38
|
+
|
39
|
+
Output shape: [:, num_classes]
|
40
|
+
|
41
|
+
index Truth index <int> or list(<int>)
|
13
42
|
|
14
|
-
|
43
|
+
index indicates which truth fields should be set.
|
44
|
+
0 indicates none, 1 is first element in truth output vector, 2 2nd element, etc.
|
15
45
|
|
16
|
-
|
17
|
-
|
46
|
+
Examples:
|
47
|
+
index = 5 truth in class 5, truth(4, 1)
|
48
|
+
index = [1, 5] truth in classes 1 and 5, truth([0, 4], 1)
|
18
49
|
|
19
|
-
|
20
|
-
|
21
|
-
|
50
|
+
In mutually-exclusive mode, a frame is expected to only
|
51
|
+
belong to one class and thus all probabilities must sum to
|
52
|
+
1. This is effectively truth for a classifier with multichannel
|
53
|
+
softmax output.
|
54
|
+
|
55
|
+
For multi-label classification each class is an individual
|
56
|
+
probability for that class and any given frame can be
|
57
|
+
assigned to multiple classes/labels, i.e., the classes are
|
58
|
+
not mutually-exclusive. For example, a NN classifier with
|
59
|
+
multichannel sigmoid output. In this case, index could
|
60
|
+
also be a vector with multiple class indices.
|
22
61
|
"""
|
23
62
|
import numpy as np
|
24
63
|
import torch
|
25
64
|
from pyaaware import SED
|
26
65
|
|
27
|
-
|
28
|
-
|
29
|
-
if data.config.config is None:
|
30
|
-
raise SonusAIError('Truth function SED missing config')
|
66
|
+
if len(data.target_audio) % config.frame_size != 0:
|
67
|
+
raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")
|
31
68
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
thresholds = data.config.config['thresholds']
|
38
|
-
if not _strictly_decreasing(thresholds):
|
39
|
-
raise SonusAIError(f'Truth function SED thresholds are not strictly decreasing: {thresholds}')
|
40
|
-
|
41
|
-
if len(data.target_audio) % data.frame_size != 0:
|
42
|
-
raise SonusAIError(f'Number of samples in audio is not a multiple of {data.frame_size}')
|
69
|
+
frames = config.target_fft.frames(data.target_audio)
|
70
|
+
parameters = sed_parameters(config)
|
71
|
+
if config.target_gain == 0:
|
72
|
+
return np.zeros((frames, parameters), dtype=np.float32)
|
43
73
|
|
44
74
|
# SED wants 1-based indices
|
45
|
-
s = SED(
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
75
|
+
s = SED(
|
76
|
+
thresholds=config.config["thresholds"],
|
77
|
+
index=config.class_indices,
|
78
|
+
frame_size=config.frame_size,
|
79
|
+
num_classes=config.num_classes,
|
80
|
+
)
|
50
81
|
|
51
|
-
|
52
|
-
|
53
|
-
if len(energy_t) != len(data.offsets):
|
54
|
-
raise SonusAIError(f'Number of frames in energy_t, {len(energy_t)},'
|
55
|
-
f' is not number of frames in truth, {len(data.offsets)}')
|
82
|
+
# Back out target gain
|
83
|
+
target_audio = data.target_audio / config.target_gain
|
56
84
|
|
57
|
-
|
58
|
-
|
59
|
-
data.truth[offset:offset + data.frame_size] = np.reshape(new_truth, (1, len(new_truth)))
|
85
|
+
# Compute energy
|
86
|
+
target_energy = config.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
|
60
87
|
|
61
|
-
|
88
|
+
if frames != target_energy.shape[0]:
|
89
|
+
raise ValueError("Incorrect frames calculation in sed truth function")
|
62
90
|
|
63
|
-
|
64
|
-
def _strictly_decreasing(list_to_check: list) -> bool:
|
65
|
-
return all(x > y for x, y in zip(list_to_check, list_to_check[1:]))
|
91
|
+
return s.execute_all(target_energy)
|
@@ -1,146 +1,109 @@
|
|
1
|
-
from sonusai import ForwardTransform
|
2
|
-
|
3
1
|
from sonusai.mixture.datatypes import AudioF
|
4
|
-
from sonusai.mixture.datatypes import AudioT
|
5
2
|
from sonusai.mixture.datatypes import Truth
|
6
|
-
from sonusai.mixture.truth_functions.
|
7
|
-
|
8
|
-
|
9
|
-
def target_f(data: Data) -> Truth:
|
10
|
-
"""Frequency domain target truth function
|
3
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
|
4
|
+
from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
|
11
5
|
|
12
|
-
Calculates the true transform of the target using the STFT
|
13
|
-
configuration defined by the feature. This will include a
|
14
|
-
forward transform window if defined by the feature.
|
15
6
|
|
16
|
-
|
17
|
-
|
18
|
-
"""
|
19
|
-
from sonusai import SonusAIError
|
7
|
+
def target_f_validate(_config: dict) -> None:
|
8
|
+
pass
|
20
9
|
|
21
|
-
if data.config.num_classes != data.feature_parameters:
|
22
|
-
raise SonusAIError(f'Invalid num_classes for target_f truth: {data.config.num_classes}')
|
23
10
|
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
offset=offset,
|
28
|
-
frame_size=data.frame_size,
|
29
|
-
zero_based_indices=data.zero_based_indices,
|
30
|
-
bins=data.target_fft.bins,
|
31
|
-
ttype=data.ttype,
|
32
|
-
start=0,
|
33
|
-
truth=data.truth)
|
11
|
+
def target_f_parameters(config: TruthFunctionConfig) -> int:
|
12
|
+
if config.ttype == "tdac-co":
|
13
|
+
return config.target_fft.bins
|
34
14
|
|
35
|
-
return
|
15
|
+
return config.target_fft.bins * 2
|
36
16
|
|
37
17
|
|
38
|
-
|
39
|
-
|
40
|
-
"""Frequency domain target and mixture truth function
|
18
|
+
def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
19
|
+
"""Frequency domain target truth function
|
41
20
|
|
42
|
-
Calculates the true transform of the target
|
43
|
-
|
44
|
-
|
45
|
-
feature.
|
21
|
+
Calculates the true transform of the target using the STFT
|
22
|
+
configuration defined by the feature. This will include a
|
23
|
+
forward transform window if defined by the feature.
|
46
24
|
|
47
|
-
Output shape: [:, 2 *
|
48
|
-
|
49
|
-
(mixture stacked real, imag; or real only for tdac-co)
|
25
|
+
Output shape: [:, 2 * bins] (target stacked real, imag) or
|
26
|
+
[:, bins] (target real only for tdac-co)
|
50
27
|
"""
|
51
|
-
|
28
|
+
import torch
|
52
29
|
|
53
|
-
|
54
|
-
|
30
|
+
target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
|
31
|
+
return _stack_real_imag(target_freq, config.ttype)
|
55
32
|
|
56
|
-
target_freq = _execute_fft(data.target_audio, data.target_fft, len(data.offsets))
|
57
|
-
mixture_freq = _execute_fft(data.mixture_audio, data.mixture_fft, len(data.offsets))
|
58
33
|
|
59
|
-
|
60
|
-
|
61
|
-
offset=offset,
|
62
|
-
frame_size=data.frame_size,
|
63
|
-
zero_based_indices=data.zero_based_indices,
|
64
|
-
bins=data.target_fft.bins,
|
65
|
-
ttype=data.ttype,
|
66
|
-
start=0,
|
67
|
-
truth=data.truth)
|
34
|
+
def target_mixture_f_validate(_config: dict) -> None:
|
35
|
+
pass
|
68
36
|
|
69
|
-
data.truth = _stack_real_imag(data=mixture_freq[idx],
|
70
|
-
offset=offset,
|
71
|
-
frame_size=data.frame_size,
|
72
|
-
zero_based_indices=data.zero_based_indices,
|
73
|
-
bins=data.target_fft.bins,
|
74
|
-
ttype=data.ttype,
|
75
|
-
start=data.target_fft.bins * 2,
|
76
|
-
truth=data.truth)
|
77
37
|
|
78
|
-
|
38
|
+
def target_mixture_f_parameters(config: TruthFunctionConfig) -> int:
|
39
|
+
if config.ttype == "tdac-co":
|
40
|
+
return config.target_fft.bins * 2
|
79
41
|
|
42
|
+
return config.target_fft.bins * 4
|
80
43
|
|
81
|
-
def target_swin_f(data: Data) -> Truth:
|
82
|
-
"""Frequency domain target with synthesis window truth function
|
83
44
|
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
the
|
45
|
+
def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
46
|
+
"""Frequency domain target and mixture truth function
|
47
|
+
|
48
|
+
Calculates the true transform of the target and the mixture
|
49
|
+
using the STFT configuration defined by the feature. This
|
50
|
+
will include a forward transform window if defined by the
|
51
|
+
feature.
|
88
52
|
|
89
|
-
Output shape: [:,
|
53
|
+
Output shape: [:, 4 * bins] (target stacked real, imag; mixture stacked real, imag) or
|
54
|
+
[:, 2 * bins] (target real; mixture real for tdac-co)
|
90
55
|
"""
|
91
56
|
import numpy as np
|
57
|
+
import torch
|
92
58
|
|
93
|
-
|
59
|
+
target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
|
60
|
+
mixture_freq = config.mixture_fft.execute_all(torch.from_numpy(data.mixture_audio))[0].numpy()
|
94
61
|
|
95
|
-
|
96
|
-
|
62
|
+
frames, bins = target_freq.shape
|
63
|
+
truth = np.empty((frames, bins * 4), dtype=np.float32)
|
64
|
+
truth[:, : bins * 2] = _stack_real_imag(target_freq, config.ttype)
|
65
|
+
truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, config.ttype)
|
66
|
+
return truth
|
97
67
|
|
98
|
-
for idx, offset in enumerate(data.offsets):
|
99
|
-
target_freq, _ = data.target_fft.execute(
|
100
|
-
np.multiply(data.target_audio[offset:offset + data.frame_size], data.swin))
|
101
68
|
|
102
|
-
|
103
|
-
|
104
|
-
bins = _get_bin_slice(index, data.target_fft.bins)
|
105
|
-
data.truth[indices, bins] = np.real(target_freq[idx])
|
69
|
+
def target_swin_f_validate(_config: dict) -> None:
|
70
|
+
pass
|
106
71
|
|
107
|
-
bins = _get_bin_slice(bins.stop, data.target_fft.bins)
|
108
|
-
data.truth[indices, bins] = np.imag(target_freq[idx])
|
109
72
|
|
110
|
-
|
73
|
+
def target_swin_f_parameters(config: TruthFunctionConfig) -> int:
|
74
|
+
return config.target_fft.bins * 2
|
111
75
|
|
112
76
|
|
113
|
-
def
|
114
|
-
|
115
|
-
|
77
|
+
def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
|
78
|
+
"""Frequency domain target with synthesis window truth function
|
79
|
+
|
80
|
+
Calculates the true transform of the target using the STFT
|
81
|
+
configuration defined by the feature. This will include a
|
82
|
+
forward transform window if defined by the feature and also
|
83
|
+
the inverse transform (or synthesis) window.
|
116
84
|
|
117
|
-
|
118
|
-
|
119
|
-
|
120
|
-
return freq
|
85
|
+
Output shape: [:, 2 * bins] (stacked real, imag)
|
86
|
+
"""
|
87
|
+
import numpy as np
|
121
88
|
|
89
|
+
from sonusai.utils import stack_complex
|
122
90
|
|
123
|
-
|
124
|
-
|
91
|
+
truth = np.empty((len(data.target_audio) // config.frame_size, config.target_fft.bins * 2), dtype=np.float32)
|
92
|
+
for idx, offset in enumerate(range(0, len(data.target_audio), config.frame_size)):
|
93
|
+
target_freq = config.target_fft.execute(
|
94
|
+
np.multiply(data.target_audio[offset : offset + config.frame_size], config.swin)
|
95
|
+
)[0]
|
96
|
+
truth[idx] = stack_complex(target_freq)
|
125
97
|
|
98
|
+
return truth
|
126
99
|
|
127
|
-
|
128
|
-
|
129
|
-
frame_size: int,
|
130
|
-
zero_based_indices: list[int],
|
131
|
-
bins: int,
|
132
|
-
ttype: str,
|
133
|
-
start: int,
|
134
|
-
truth: Truth) -> Truth:
|
100
|
+
|
101
|
+
def _stack_real_imag(data: AudioF, ttype: str) -> Truth:
|
135
102
|
import numpy as np
|
136
103
|
|
137
|
-
|
138
|
-
for index in zero_based_indices:
|
139
|
-
b = _get_bin_slice(index + start, bins)
|
140
|
-
truth[i, b] = np.real(data)
|
104
|
+
from sonusai.utils import stack_complex
|
141
105
|
|
142
|
-
|
143
|
-
|
144
|
-
truth[i, b] = np.imag(data)
|
106
|
+
if ttype == "tdac-co":
|
107
|
+
return np.real(data)
|
145
108
|
|
146
|
-
return
|
109
|
+
return stack_complex(data)
|
sonusai/mkwav.py
CHANGED
@@ -16,18 +16,16 @@ Inputs:
|
|
16
16
|
MIXID A glob of mixture ID(s) to generate.
|
17
17
|
|
18
18
|
Outputs the following to the mixture database directory:
|
19
|
-
<id>
|
20
|
-
|
21
|
-
|
22
|
-
|
19
|
+
<id>
|
20
|
+
mixture.wav: mixture
|
21
|
+
target.wav: target (optional)
|
22
|
+
noise.wav: noise (optional)
|
23
|
+
metadata.txt
|
23
24
|
mkwav.log
|
24
25
|
|
25
26
|
"""
|
26
|
-
import signal
|
27
|
-
from dataclasses import dataclass
|
28
27
|
|
29
|
-
|
30
|
-
from sonusai.mixture import MixtureDatabase
|
28
|
+
import signal
|
31
29
|
|
32
30
|
|
33
31
|
def signal_handler(_sig, _frame):
|
@@ -35,79 +33,33 @@ def signal_handler(_sig, _frame):
|
|
35
33
|
|
36
34
|
from sonusai import logger
|
37
35
|
|
38
|
-
logger.info(
|
36
|
+
logger.info("Canceled due to keyboard interrupt")
|
39
37
|
sys.exit(1)
|
40
38
|
|
41
39
|
|
42
40
|
signal.signal(signal.SIGINT, signal_handler)
|
43
41
|
|
44
42
|
|
45
|
-
|
46
|
-
class MPGlobal:
|
47
|
-
mixdb: MixtureDatabase = None
|
48
|
-
write_target: bool = None
|
49
|
-
write_noise: bool = None
|
50
|
-
|
51
|
-
|
52
|
-
MP_GLOBAL = MPGlobal()
|
53
|
-
|
54
|
-
|
55
|
-
def mkwav(location: str, mixid: int) -> tuple[AudioT, AudioT, AudioT]:
|
56
|
-
import numpy as np
|
57
|
-
|
58
|
-
from sonusai.genmix import genmix
|
59
|
-
|
60
|
-
data = genmix(location=location, mixids=mixid, force=False)
|
61
|
-
|
62
|
-
return data[0].mixture, np.sum(data[0].targets, axis=0), data[0].noise
|
63
|
-
|
64
|
-
|
65
|
-
def _process_mixture(mixid: int) -> None:
|
66
|
-
from os.path import exists
|
43
|
+
def _process_mixture(m_id: int, location: str, write_target: bool, write_noise: bool) -> None:
|
67
44
|
from os.path import join
|
68
|
-
from os.path import splitext
|
69
|
-
|
70
|
-
import h5py
|
71
|
-
import numpy as np
|
72
45
|
|
73
|
-
from sonusai.mixture import
|
46
|
+
from sonusai.mixture import MixtureDatabase
|
47
|
+
from sonusai.mixture import write_mixture_metadata
|
74
48
|
from sonusai.utils import float_to_int16
|
75
49
|
from sonusai.utils import write_audio
|
76
50
|
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
need_data = True
|
90
|
-
if MP_GLOBAL.write_noise and 'noise' not in f:
|
91
|
-
need_data = True
|
92
|
-
|
93
|
-
if need_data:
|
94
|
-
mixture, target, noise = mkwav(location=MP_GLOBAL.mixdb.location, mixid=mixid)
|
95
|
-
else:
|
96
|
-
with h5py.File(mixture_filename, 'r') as f:
|
97
|
-
mixture = np.array(f['mixture'])
|
98
|
-
if MP_GLOBAL.write_target:
|
99
|
-
target = np.sum(np.array(f['targets']), axis=0)
|
100
|
-
if MP_GLOBAL.write_noise:
|
101
|
-
noise = np.array(f['noise'])
|
102
|
-
|
103
|
-
write_audio(name=mixture_basename + '_mixture.wav', audio=float_to_int16(mixture))
|
104
|
-
if MP_GLOBAL.write_target:
|
105
|
-
write_audio(name=mixture_basename + '_target.wav', audio=float_to_int16(target))
|
106
|
-
if MP_GLOBAL.write_noise:
|
107
|
-
write_audio(name=mixture_basename + '_noise.wav', audio=float_to_int16(noise))
|
108
|
-
|
109
|
-
with open(file=mixture_basename + '.txt', mode='w') as f:
|
110
|
-
f.write(mixture_metadata(MP_GLOBAL.mixdb, MP_GLOBAL.mixdb.mixture(mixid)))
|
51
|
+
mixdb = MixtureDatabase(location)
|
52
|
+
|
53
|
+
mixture = mixdb.mixture(m_id)
|
54
|
+
location = join(mixdb.location, mixture.name)
|
55
|
+
|
56
|
+
write_audio(name=join(location, "mixture.wav"), audio=float_to_int16(mixdb.mixture_mixture(m_id)))
|
57
|
+
if write_target:
|
58
|
+
write_audio(name=join(location, "target.wav"), audio=float_to_int16(mixdb.mixture_target(m_id)))
|
59
|
+
if write_noise:
|
60
|
+
write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
|
61
|
+
|
62
|
+
write_mixture_metadata(mixdb, mixture)
|
111
63
|
|
112
64
|
|
113
65
|
def main() -> None:
|
@@ -118,63 +70,68 @@ def main() -> None:
|
|
118
70
|
|
119
71
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
120
72
|
|
121
|
-
verbose = args[
|
122
|
-
mixid = args[
|
123
|
-
|
124
|
-
|
125
|
-
location = args[
|
73
|
+
verbose = args["--verbose"]
|
74
|
+
mixid = args["--mixid"]
|
75
|
+
write_target = args["--target"]
|
76
|
+
write_noise = args["--noise"]
|
77
|
+
location = args["LOC"]
|
126
78
|
|
127
79
|
import time
|
80
|
+
from functools import partial
|
128
81
|
from os.path import join
|
129
82
|
|
130
|
-
from tqdm import tqdm
|
131
|
-
|
132
83
|
import sonusai
|
133
84
|
from sonusai import create_file_handler
|
134
85
|
from sonusai import initial_log_messages
|
135
86
|
from sonusai import logger
|
136
87
|
from sonusai import update_console_handler
|
88
|
+
from sonusai.mixture import MixtureDatabase
|
137
89
|
from sonusai.mixture import check_audio_files_exist
|
138
|
-
from sonusai.utils import pp_tqdm_imap
|
139
90
|
from sonusai.utils import human_readable_size
|
91
|
+
from sonusai.utils import par_track
|
140
92
|
from sonusai.utils import seconds_to_hms
|
93
|
+
from sonusai.utils import track
|
141
94
|
|
142
95
|
start_time = time.monotonic()
|
143
96
|
|
144
|
-
create_file_handler(join(location,
|
97
|
+
create_file_handler(join(location, "mkwav.log"))
|
145
98
|
update_console_handler(verbose)
|
146
|
-
initial_log_messages(
|
99
|
+
initial_log_messages("mkwav")
|
147
100
|
|
148
|
-
logger.info(f
|
149
|
-
|
150
|
-
mixid =
|
101
|
+
logger.info(f"Load mixture database from {location}")
|
102
|
+
mixdb = MixtureDatabase(location)
|
103
|
+
mixid = mixdb.mixids_to_list(mixid)
|
151
104
|
|
152
|
-
total_samples =
|
105
|
+
total_samples = mixdb.total_samples(mixid)
|
153
106
|
duration = total_samples / sonusai.mixture.SAMPLE_RATE
|
154
107
|
|
155
|
-
logger.info(
|
156
|
-
logger.info(f
|
157
|
-
logger.info(f
|
108
|
+
logger.info("")
|
109
|
+
logger.info(f"Found {len(mixid):,} mixtures to process")
|
110
|
+
logger.info(f"{total_samples:,} samples")
|
158
111
|
|
159
|
-
check_audio_files_exist(
|
112
|
+
check_audio_files_exist(mixdb)
|
160
113
|
|
161
|
-
progress =
|
162
|
-
|
114
|
+
progress = track(total=len(mixid))
|
115
|
+
par_track(
|
116
|
+
partial(_process_mixture, location=location, write_target=write_target, write_noise=write_noise),
|
117
|
+
mixid,
|
118
|
+
progress=progress,
|
119
|
+
)
|
163
120
|
progress.close()
|
164
121
|
|
165
|
-
logger.info(f
|
166
|
-
logger.info(
|
167
|
-
logger.info(f
|
168
|
-
logger.info(f
|
169
|
-
if
|
170
|
-
logger.info(f
|
171
|
-
if
|
172
|
-
logger.info(f
|
122
|
+
logger.info(f"Wrote {len(mixid)} mixtures to {location}")
|
123
|
+
logger.info("")
|
124
|
+
logger.info(f"Duration: {seconds_to_hms(seconds=duration)}")
|
125
|
+
logger.info(f"mixture: {human_readable_size(total_samples * 2, 1)}")
|
126
|
+
if write_target:
|
127
|
+
logger.info(f"target: {human_readable_size(total_samples * 2, 1)}")
|
128
|
+
if write_noise:
|
129
|
+
logger.info(f"noise: {human_readable_size(total_samples * 2, 1)}")
|
173
130
|
|
174
131
|
end_time = time.monotonic()
|
175
|
-
logger.info(f
|
176
|
-
logger.info(
|
132
|
+
logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
|
133
|
+
logger.info("")
|
177
134
|
|
178
135
|
|
179
|
-
if __name__ ==
|
136
|
+
if __name__ == "__main__":
|
180
137
|
main()
|