sonusai 0.15.8__py3-none-any.whl → 0.16.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +35 -4
- sonusai/audiofe.py +237 -0
- sonusai/calc_metric_spenh.py +21 -12
- sonusai/genft.py +2 -1
- sonusai/genmixdb.py +5 -5
- sonusai/lsdb.py +2 -2
- sonusai/main.py +58 -61
- sonusai/mixture/__init__.py +4 -2
- sonusai/mixture/audio.py +0 -34
- sonusai/mixture/config.py +1 -2
- sonusai/mixture/datatypes.py +1 -1
- sonusai/mixture/feature.py +75 -21
- sonusai/mixture/helpers.py +60 -30
- sonusai/mixture/log_duration_and_sizes.py +2 -2
- sonusai/mixture/mixdb.py +13 -10
- sonusai/mixture/spectral_mask.py +14 -14
- sonusai/mixture/truth_functions/data.py +1 -1
- sonusai/mixture/truth_functions/target.py +2 -2
- sonusai/mkmanifest.py +29 -2
- sonusai/onnx_predict.py +1 -1
- sonusai/plot.py +4 -4
- sonusai/post_spenh_targetf.py +8 -8
- sonusai/utils/__init__.py +8 -7
- sonusai/utils/asl_p56.py +3 -3
- sonusai/utils/asr.py +35 -8
- sonusai/utils/asr_functions/__init__.py +0 -5
- sonusai/utils/asr_functions/aaware_whisper.py +2 -2
- sonusai/utils/asr_manifest_functions/__init__.py +1 -0
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +29 -0
- sonusai/utils/audio_devices.py +41 -0
- sonusai/utils/calculate_input_shape.py +3 -4
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/{trim_docstring.py → docstring.py} +20 -0
- sonusai/utils/model_utils.py +30 -0
- sonusai/utils/onnx_utils.py +19 -45
- sonusai/utils/reshape.py +11 -11
- sonusai/utils/wave.py +12 -5
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/METADATA +8 -19
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/RECORD +41 -54
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/WHEEL +1 -1
- sonusai/data_generator/__init__.py +0 -5
- sonusai/data_generator/dataset_from_mixdb.py +0 -143
- sonusai/data_generator/keras_from_mixdb.py +0 -169
- sonusai/data_generator/torch_from_mixdb.py +0 -122
- sonusai/evaluate.py +0 -245
- sonusai/keras_onnx.py +0 -86
- sonusai/keras_predict.py +0 -231
- sonusai/keras_train.py +0 -334
- sonusai/torchl_onnx.py +0 -216
- sonusai/torchl_predict.py +0 -547
- sonusai/torchl_train.py +0 -223
- sonusai/utils/asr_functions/aixplain_whisper.py +0 -59
- sonusai/utils/asr_functions/data.py +0 -16
- sonusai/utils/asr_functions/deepgram.py +0 -97
- sonusai/utils/asr_functions/fastwhisper.py +0 -90
- sonusai/utils/asr_functions/google.py +0 -95
- sonusai/utils/asr_functions/whisper.py +0 -49
- sonusai/utils/keras_utils.py +0 -226
- {sonusai-0.15.8.dist-info → sonusai-0.16.0.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py
CHANGED
@@ -5,6 +5,24 @@ from os.path import dirname
|
|
5
5
|
__version__ = metadata.version(__package__)
|
6
6
|
BASEDIR = dirname(__file__)
|
7
7
|
|
8
|
+
commands_doc = """
|
9
|
+
audiofe Audio front end
|
10
|
+
calc_metric_spenh Run speech enhancement and analysis
|
11
|
+
doc Documentation
|
12
|
+
genft Generate feature and truth data
|
13
|
+
genmix Generate mixture and truth data
|
14
|
+
genmixdb Generate a mixture database
|
15
|
+
gentcst Generate target configuration from a subdirectory tree
|
16
|
+
lsdb List information about a mixture database
|
17
|
+
mkmanifest Make ASR manifest JSON file
|
18
|
+
mkwav Make WAV files from a mixture database
|
19
|
+
onnx_predict Run ONNX predict on a trained model
|
20
|
+
plot Plot mixture data
|
21
|
+
post_spenh_targetf Run post-processing for speech enhancement targetf data
|
22
|
+
tplot Plot truth data
|
23
|
+
vars List custom SonusAI variables
|
24
|
+
"""
|
25
|
+
|
8
26
|
# create logger
|
9
27
|
logger = logging.getLogger('sonusai')
|
10
28
|
logger.setLevel(logging.DEBUG)
|
@@ -21,7 +39,7 @@ class SonusAIError(Exception):
|
|
21
39
|
|
22
40
|
|
23
41
|
# create file handler
|
24
|
-
def create_file_handler(filename: str):
|
42
|
+
def create_file_handler(filename: str) -> None:
|
25
43
|
fh = logging.FileHandler(filename=filename, mode='w')
|
26
44
|
fh.setLevel(logging.DEBUG)
|
27
45
|
fh.setFormatter(formatter)
|
@@ -29,7 +47,7 @@ def create_file_handler(filename: str):
|
|
29
47
|
|
30
48
|
|
31
49
|
# update console handler
|
32
|
-
def update_console_handler(verbose: bool):
|
50
|
+
def update_console_handler(verbose: bool) -> None:
|
33
51
|
if not verbose:
|
34
52
|
logger.removeHandler(console_handler)
|
35
53
|
console_handler.setLevel(logging.INFO)
|
@@ -37,14 +55,17 @@ def update_console_handler(verbose: bool):
|
|
37
55
|
|
38
56
|
|
39
57
|
# write initial log message
|
40
|
-
def initial_log_messages(name: str):
|
58
|
+
def initial_log_messages(name: str, subprocess: str = None) -> None:
|
41
59
|
from datetime import datetime
|
42
60
|
from getpass import getuser
|
43
61
|
from os import getcwd
|
44
62
|
from socket import gethostname
|
45
63
|
from sys import argv
|
46
64
|
|
47
|
-
|
65
|
+
if subprocess is None:
|
66
|
+
logger.info(f'SonusAI {__version__}')
|
67
|
+
else:
|
68
|
+
logger.info(f'SonusAI {subprocess}')
|
48
69
|
logger.info(f'{name}')
|
49
70
|
logger.info('')
|
50
71
|
logger.debug(f'Host: {gethostname()}')
|
@@ -53,3 +74,13 @@ def initial_log_messages(name: str):
|
|
53
74
|
logger.debug(f'Date: {datetime.now()}')
|
54
75
|
logger.debug(f'Command: {" ".join(argv)}')
|
55
76
|
logger.debug('')
|
77
|
+
|
78
|
+
|
79
|
+
def commands_list(doc: str = commands_doc) -> list[str]:
|
80
|
+
lines = doc.split('\n')
|
81
|
+
commands = []
|
82
|
+
for line in lines:
|
83
|
+
command = line.strip().split(' ').pop(0)
|
84
|
+
if command:
|
85
|
+
commands.append(command)
|
86
|
+
return commands
|
sonusai/audiofe.py
ADDED
@@ -0,0 +1,237 @@
|
|
1
|
+
"""sonusai audiofe
|
2
|
+
|
3
|
+
usage: audiofe [-hvds] [--version] [-i INPUT] [-l LENGTH] [-m MODEL] [-k CKPT] [-a ASR] [-w WMODEL]
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-d, --debug Write debug data to H5 file.
|
9
|
+
-s, --show Show a list of available audio inputs.
|
10
|
+
-i INPUT, --input INPUT Input audio.
|
11
|
+
-l LENGTH, --length LENGTH Length of audio in seconds. [default: -1].
|
12
|
+
-m MODEL, --model MODEL PL model .py file path.
|
13
|
+
-k CKPT, --checkpoint CKPT PL checkpoint file with weights.
|
14
|
+
-a ASR, --asr ASR ASR method to use.
|
15
|
+
-w WMODEL, --whisper WMODEL Whisper model used in aixplain_whisper and whisper methods. [default: tiny].
|
16
|
+
|
17
|
+
Aaware SonusAI Audio Front End.
|
18
|
+
|
19
|
+
Capture LENGTH seconds of audio from INPUT. If LENGTH is < 0, then capture until key is pressed. If INPUT is a valid
|
20
|
+
audio file name, then use the audio data from the specified file. In this case, if LENGTH is < 0, process entire file;
|
21
|
+
otherwise, process min(length(INPUT), LENGTH) seconds of audio from INPUT. Audio is saved to
|
22
|
+
audiofe_capture_<TIMESTAMP>.wav.
|
23
|
+
|
24
|
+
If a model is specified, run prediction on audio data from this model. Then compute the inverse transform of the
|
25
|
+
prediction result and save to audiofe_predict_<TIMESTAMP>.wav.
|
26
|
+
|
27
|
+
If an ASR is specified, run ASR on the captured audio and print the results. In addition, if a model was also specified,
|
28
|
+
run ASR on the predict audio and print the results.
|
29
|
+
|
30
|
+
If the debug option is enabled, write capture audio, feature, reconstruct audio, predict, and predict audio to
|
31
|
+
audiofe_<TIMESTAMP>.h5.
|
32
|
+
|
33
|
+
"""
|
34
|
+
from os.path import exists
|
35
|
+
from select import select
|
36
|
+
from sys import stdin
|
37
|
+
|
38
|
+
import h5py
|
39
|
+
import numpy as np
|
40
|
+
import pyaudio
|
41
|
+
import torch
|
42
|
+
from docopt import docopt
|
43
|
+
from docopt import printable_usage
|
44
|
+
|
45
|
+
import sonusai
|
46
|
+
from sonusai import create_file_handler
|
47
|
+
from sonusai import initial_log_messages
|
48
|
+
from sonusai import logger
|
49
|
+
from sonusai import update_console_handler
|
50
|
+
from sonusai.mixture import AudioT
|
51
|
+
from sonusai.mixture import CHANNEL_COUNT
|
52
|
+
from sonusai.mixture import SAMPLE_RATE
|
53
|
+
from sonusai.mixture import get_audio_from_feature
|
54
|
+
from sonusai.mixture import get_feature_from_audio
|
55
|
+
from sonusai.mixture import read_audio
|
56
|
+
from sonusai.utils import calc_asr
|
57
|
+
from sonusai.utils import create_timestamp
|
58
|
+
from sonusai.utils import get_input_device_index_by_name
|
59
|
+
from sonusai.utils import get_input_devices
|
60
|
+
from sonusai.utils import load_torchl_ckpt_model
|
61
|
+
from sonusai.utils import trim_docstring
|
62
|
+
from sonusai.utils import write_wav
|
63
|
+
|
64
|
+
|
65
|
+
def main() -> None:
|
66
|
+
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
67
|
+
ts = create_timestamp()
|
68
|
+
|
69
|
+
verbose = args['--verbose']
|
70
|
+
length = float(args['--length'])
|
71
|
+
input_name = args['--input']
|
72
|
+
model_name = args['--model']
|
73
|
+
ckpt_name = args['--checkpoint']
|
74
|
+
asr_name = args['--asr']
|
75
|
+
whisper_name = args['--whisper']
|
76
|
+
debug = args['--debug']
|
77
|
+
show = args['--show']
|
78
|
+
|
79
|
+
capture_name = f'audiofe_capture_{ts}.wav'
|
80
|
+
predict_name = f'audiofe_predict_{ts}.wav'
|
81
|
+
h5_name = f'audiofe_{ts}.h5'
|
82
|
+
|
83
|
+
if model_name is not None and ckpt_name is None:
|
84
|
+
print(printable_usage(trim_docstring(__doc__)))
|
85
|
+
exit(1)
|
86
|
+
|
87
|
+
# Setup logging file
|
88
|
+
create_file_handler('audiofe.log')
|
89
|
+
update_console_handler(verbose)
|
90
|
+
initial_log_messages('audiofe')
|
91
|
+
|
92
|
+
if show:
|
93
|
+
logger.info('List of available audio inputs:')
|
94
|
+
logger.info('')
|
95
|
+
p = pyaudio.PyAudio()
|
96
|
+
for name in get_input_devices(p):
|
97
|
+
logger.info(f'{name}')
|
98
|
+
logger.info('')
|
99
|
+
p.terminate()
|
100
|
+
return
|
101
|
+
|
102
|
+
if input_name is not None and exists(input_name):
|
103
|
+
capture_audio = get_frames_from_file(input_name, length)
|
104
|
+
else:
|
105
|
+
try:
|
106
|
+
capture_audio = get_frames_from_device(input_name, length)
|
107
|
+
except ValueError as e:
|
108
|
+
logger.exception(e)
|
109
|
+
return
|
110
|
+
|
111
|
+
write_wav(capture_name, capture_audio, SAMPLE_RATE)
|
112
|
+
logger.info('')
|
113
|
+
logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_name}')
|
114
|
+
if debug:
|
115
|
+
with h5py.File(h5_name, 'a') as f:
|
116
|
+
if 'capture_audio' in f:
|
117
|
+
del f['capture_audio']
|
118
|
+
f.create_dataset('capture_audio', data=capture_audio)
|
119
|
+
logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {h5_name}')
|
120
|
+
|
121
|
+
if asr_name is not None:
|
122
|
+
capture_asr = calc_asr(capture_audio, engine=asr_name, whisper_model_name=whisper_name).text
|
123
|
+
logger.info(f'Capture audio ASR: {capture_asr}')
|
124
|
+
|
125
|
+
if model_name is not None:
|
126
|
+
model = load_torchl_ckpt_model(model_name=model_name, ckpt_name=ckpt_name)
|
127
|
+
model.eval()
|
128
|
+
|
129
|
+
feature = get_feature_from_audio(audio=capture_audio, feature_mode=model.hparams.feature)
|
130
|
+
if debug:
|
131
|
+
with h5py.File(h5_name, 'a') as f:
|
132
|
+
if 'feature' in f:
|
133
|
+
del f['feature']
|
134
|
+
f.create_dataset('feature', data=feature)
|
135
|
+
logger.info(f'Wrote feature with shape {feature.shape} to {h5_name}')
|
136
|
+
|
137
|
+
# if debug:
|
138
|
+
# reconstruct_name = f'audiofe_reconstruct_{ts}.wav'
|
139
|
+
# reconstruct_audio = get_audio_from_feature(feature=feature, feature_mode=model.hparams.feature)
|
140
|
+
# samples = min(len(capture_audio), len(reconstruct_audio))
|
141
|
+
# max_err = np.max(np.abs(capture_audio[:samples] - reconstruct_audio[:samples]))
|
142
|
+
# logger.info(f'Maximum error between capture and reconstruct: {max_err}')
|
143
|
+
# write_wav(reconstruct_name, reconstruct_audio, SAMPLE_RATE)
|
144
|
+
# logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {reconstruct_name}')
|
145
|
+
# with h5py.File(h5_name, 'a') as f:
|
146
|
+
# if 'reconstruct_audio' in f:
|
147
|
+
# del f['reconstruct_audio']
|
148
|
+
# f.create_dataset('reconstruct_audio', data=reconstruct_audio)
|
149
|
+
# logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {h5_name}')
|
150
|
+
|
151
|
+
with torch.no_grad():
|
152
|
+
# model wants batch x timesteps x feature_parameters
|
153
|
+
predict = model(torch.tensor(feature).permute((1, 0, 2))).permute(1, 0, 2).numpy()
|
154
|
+
if debug:
|
155
|
+
with h5py.File(h5_name, 'a') as f:
|
156
|
+
if 'predict' in f:
|
157
|
+
del f['predict']
|
158
|
+
f.create_dataset('predict', data=predict)
|
159
|
+
logger.info(f'Wrote predict with shape {predict.shape} to {h5_name}')
|
160
|
+
|
161
|
+
predict_audio = get_audio_from_feature(feature=predict, feature_mode=model.hparams.feature)
|
162
|
+
write_wav(predict_name, predict_audio, SAMPLE_RATE)
|
163
|
+
logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_name}')
|
164
|
+
if debug:
|
165
|
+
with h5py.File(h5_name, 'a') as f:
|
166
|
+
if 'predict_audio' in f:
|
167
|
+
del f['predict_audio']
|
168
|
+
f.create_dataset('predict_audio', data=predict_audio)
|
169
|
+
logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {h5_name}')
|
170
|
+
|
171
|
+
if asr_name is not None:
|
172
|
+
predict_asr = calc_asr(predict_audio, engine=asr_name, whisper_model_name=whisper_name).text
|
173
|
+
logger.info(f'Predict audio ASR: {predict_asr}')
|
174
|
+
|
175
|
+
|
176
|
+
def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1024) -> AudioT:
|
177
|
+
p = pyaudio.PyAudio()
|
178
|
+
|
179
|
+
input_devices = get_input_devices(p)
|
180
|
+
if not input_devices:
|
181
|
+
raise ValueError('No input audio devices found')
|
182
|
+
|
183
|
+
if input_name is None:
|
184
|
+
input_name = input_devices[0]
|
185
|
+
|
186
|
+
try:
|
187
|
+
device_index = get_input_device_index_by_name(p, input_name)
|
188
|
+
except ValueError:
|
189
|
+
msg = f'Could not find {input_name}\n'
|
190
|
+
msg += f'Available devices:\n'
|
191
|
+
for input_device in input_devices:
|
192
|
+
msg += f' {input_device}\n'
|
193
|
+
raise ValueError(msg)
|
194
|
+
|
195
|
+
logger.info(f'Capturing from {p.get_device_info_by_index(device_index).get("name")}')
|
196
|
+
stream = p.open(format=pyaudio.paFloat32,
|
197
|
+
channels=CHANNEL_COUNT,
|
198
|
+
rate=SAMPLE_RATE,
|
199
|
+
input=True,
|
200
|
+
input_device_index=device_index)
|
201
|
+
stream.start_stream()
|
202
|
+
|
203
|
+
print()
|
204
|
+
print('+---------------------------------+')
|
205
|
+
print('| Press Enter to stop |')
|
206
|
+
print('+---------------------------------+')
|
207
|
+
print()
|
208
|
+
|
209
|
+
elapsed = 0.0
|
210
|
+
seconds_per_chunk = float(chunk) / float(SAMPLE_RATE)
|
211
|
+
raw_frames = []
|
212
|
+
while elapsed < length or length == -1:
|
213
|
+
raw_frames.append(stream.read(num_frames=chunk, exception_on_overflow=False))
|
214
|
+
elapsed += seconds_per_chunk
|
215
|
+
if select([stdin, ], [], [], 0)[0]:
|
216
|
+
stdin.read(1)
|
217
|
+
length = elapsed
|
218
|
+
|
219
|
+
stream.stop_stream()
|
220
|
+
stream.close()
|
221
|
+
p.terminate()
|
222
|
+
frames = np.frombuffer(b''.join(raw_frames), dtype=np.float32)
|
223
|
+
return frames
|
224
|
+
|
225
|
+
|
226
|
+
def get_frames_from_file(input_name: str, length: float) -> AudioT:
|
227
|
+
logger.info(f'Capturing from {input_name}')
|
228
|
+
frames = read_audio(input_name)
|
229
|
+
if length != -1:
|
230
|
+
num_frames = int(length * SAMPLE_RATE)
|
231
|
+
if len(frames) > num_frames:
|
232
|
+
frames = frames[:num_frames]
|
233
|
+
return frames
|
234
|
+
|
235
|
+
|
236
|
+
if __name__ == '__main__':
|
237
|
+
main()
|
sonusai/calc_metric_spenh.py
CHANGED
@@ -758,13 +758,18 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
758
758
|
predict = stack_complex(predict)
|
759
759
|
|
760
760
|
# 2) Collect true target, noise, mixture data, trim to predict size if needed
|
761
|
-
|
762
|
-
target_f = mixdb.
|
763
|
-
|
764
|
-
|
765
|
-
|
761
|
+
tmp = mixdb.mixture_targets(mixid) # targets is list of pre-IR and pre-specaugment targets
|
762
|
+
target_f = mixdb.mixture_targets_f(mixid, targets=tmp)[0]
|
763
|
+
target = tmp[0]
|
764
|
+
mixture = mixdb.mixture_mixture(mixid) # note: gives full reverberated/distorted target, but no specaugment
|
765
|
+
# noise_wodist = mixdb.mixture_noise(mixid) # noise without specaugment and distortion
|
766
|
+
# noise_wodist_f = mixdb.mixture_noise_f(mixid, noise=noise_wodist)
|
767
|
+
noise = mixture - target # has time-domain distortion (ir,etc.) but does not have specaugment
|
768
|
+
# noise_f = mixdb.mixture_noise_f(mixid, noise=noise)
|
769
|
+
segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise) # note: uses pre-IR, pre-specaug audio
|
766
770
|
mixture_f = mixdb.mixture_mixture_f(mixid, mixture=mixture)
|
767
|
-
|
771
|
+
noise_f = mixture_f - target_f # true noise in freq domain includes specaugment and time-domain ir,distortions
|
772
|
+
# segsnr_f = mixdb.mixture_segsnr(mixid, target=target, noise=noise)
|
768
773
|
segsnr_f[segsnr_f == inf] = 7.944e8 # 99db
|
769
774
|
segsnr_f[segsnr_f == -inf] = 1.258e-10 # -99db
|
770
775
|
# need to use inv-tf to match #samples & latency shift properties of predict inv tf
|
@@ -920,8 +925,9 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
|
|
920
925
|
'NLERR': lerr_n_frame,
|
921
926
|
'SPD': phd_frame})
|
922
927
|
metr2 = metr2.describe() # Use pandas stat function
|
923
|
-
|
924
|
-
|
928
|
+
# Change SSNR stats to dB, except count. SSNR is index 0, pandas requires using iloc
|
929
|
+
# metr2['SSNR'][1:] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
|
930
|
+
metr2.iloc[1:, 0] = metr2['SSNR'][1:].apply(lambda x: 10 * np.log10(x + 1.01e-10))
|
925
931
|
# create a single row in multi-column header
|
926
932
|
new_labels = pd.MultiIndex.from_product([metr2.columns,
|
927
933
|
['Avg', 'Min', 'Med', 'Max', 'Std']],
|
@@ -1166,7 +1172,7 @@ def main():
|
|
1166
1172
|
# Individual mixtures use pandas print, set precision to 2 decimal places
|
1167
1173
|
# pd.set_option('float_format', '{:.2f}'.format)
|
1168
1174
|
progress = tqdm(total=len(mixids), desc='calc_metric_spenh')
|
1169
|
-
all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, num_cpus=
|
1175
|
+
all_metrics_tables = pp_tqdm_imap(_process_mixture, mixids, progress=progress, num_cpus=8)
|
1170
1176
|
progress.close()
|
1171
1177
|
|
1172
1178
|
all_metrics_table_1 = pd.concat([item[0] for item in all_metrics_tables])
|
@@ -1192,6 +1198,7 @@ def main():
|
|
1192
1198
|
if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
|
1193
1199
|
mtab_snr_summary_em = pd.concat([mtab_snr_summary_em, tmp])
|
1194
1200
|
|
1201
|
+
mtab_snr_summary = mtab_snr_summary.sort_values(by=['MXSNR'], ascending=False)
|
1195
1202
|
# Correct percentages in snr summary table
|
1196
1203
|
mtab_snr_summary['PESQi%'] = 100 * (mtab_snr_summary['PESQ'] - mtab_snr_summary['MXPESQ']) / np.maximum(
|
1197
1204
|
mtab_snr_summary['MXPESQ'], 0.01)
|
@@ -1202,9 +1209,11 @@ def main():
|
|
1202
1209
|
else:
|
1203
1210
|
mtab_snr_summary['WERi%'].iloc[i] = -999.0
|
1204
1211
|
else:
|
1205
|
-
mtab_snr_summary['
|
1206
|
-
|
1207
|
-
|
1212
|
+
if ~np.isnan(mtab_snr_summary['WER'].iloc[i]) and ~np.isnan(mtab_snr_summary['MXWER'].iloc[i]):
|
1213
|
+
# update WERi% in 6th col
|
1214
|
+
mtab_snr_summary.iloc[i, 6] = 100 * (mtab_snr_summary['MXWER'].iloc[i] -
|
1215
|
+
mtab_snr_summary['WER'].iloc[i]) / \
|
1216
|
+
mtab_snr_summary['MXWER'].iloc[i]
|
1208
1217
|
|
1209
1218
|
# Calculate avg metrics over all mixtures except -99
|
1210
1219
|
all_mtab1_sorted_nom99 = all_mtab1_sorted[all_mtab1_sorted.MXSNR != -99]
|
sonusai/genft.py
CHANGED
@@ -165,7 +165,8 @@ def main() -> None:
|
|
165
165
|
logger.info(f'Wrote {len(mixids)} mixtures to {location}')
|
166
166
|
logger.info('')
|
167
167
|
logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
|
168
|
-
logger.info(
|
168
|
+
logger.info(
|
169
|
+
f'feature: {human_readable_size(total_feature_frames * mixdb.fg_stride * mixdb.feature_parameters * 4, 1)}')
|
169
170
|
logger.info(f'truth_f: {human_readable_size(total_feature_frames * mixdb.num_classes * 4, 1)}')
|
170
171
|
if compute_segsnr:
|
171
172
|
logger.info(f'segsnr: {human_readable_size(total_transform_frames * 4, 1)}')
|
sonusai/genmixdb.py
CHANGED
@@ -225,7 +225,7 @@ def genmixdb(location: str,
|
|
225
225
|
if logging:
|
226
226
|
logger.info('Collecting impulse responses')
|
227
227
|
|
228
|
-
impulse_response_files = get_impulse_response_files(config
|
228
|
+
impulse_response_files = get_impulse_response_files(config)
|
229
229
|
|
230
230
|
populate_impulse_response_file_table(location, impulse_response_files, test)
|
231
231
|
|
@@ -337,12 +337,12 @@ def genmixdb(location: str,
|
|
337
337
|
log_duration_and_sizes(total_duration=total_duration,
|
338
338
|
num_classes=mixdb.num_classes,
|
339
339
|
feature_step_samples=mixdb.feature_step_samples,
|
340
|
-
|
340
|
+
feature_parameters=mixdb.feature_parameters,
|
341
341
|
stride=mixdb.fg_stride,
|
342
342
|
desc='Estimated')
|
343
343
|
logger.info(f'Feature shape: '
|
344
|
-
f'{mixdb.fg_stride} x {mixdb.
|
345
|
-
f'({mixdb.fg_stride * mixdb.
|
344
|
+
f'{mixdb.fg_stride} x {mixdb.feature_parameters} '
|
345
|
+
f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
|
346
346
|
logger.info(f'Feature samples: {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
|
347
347
|
logger.info(f'Feature step samples: {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)')
|
348
348
|
logger.info('')
|
@@ -371,7 +371,7 @@ def genmixdb(location: str,
|
|
371
371
|
log_duration_and_sizes(total_duration=total_duration,
|
372
372
|
num_classes=mixdb.num_classes,
|
373
373
|
feature_step_samples=mixdb.feature_step_samples,
|
374
|
-
|
374
|
+
feature_parameters=mixdb.feature_parameters,
|
375
375
|
stride=mixdb.fg_stride,
|
376
376
|
desc='Actual')
|
377
377
|
logger.info('')
|
sonusai/lsdb.py
CHANGED
@@ -48,8 +48,8 @@ def lsdb(mixdb: MixtureDatabase,
|
|
48
48
|
logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
|
49
49
|
logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
|
50
50
|
logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
|
51
|
-
logger.info(f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.
|
52
|
-
f'({mixdb.fg_stride * mixdb.
|
51
|
+
logger.info(f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
|
52
|
+
f'({mixdb.fg_stride * mixdb.feature_parameters} total params)')
|
53
53
|
logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
|
54
54
|
logger.info(f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples '
|
55
55
|
f'({mixdb.feature_step_ms} ms)')
|
sonusai/main.py
CHANGED
@@ -3,91 +3,88 @@
|
|
3
3
|
usage: sonusai [--version] [--help] <command> [<args>...]
|
4
4
|
|
5
5
|
The sonusai commands are:
|
6
|
-
|
7
|
-
doc Documentation
|
8
|
-
evaluate Evaluate model performance
|
9
|
-
genft Generate feature and truth data
|
10
|
-
genmix Generate mixture and truth data
|
11
|
-
genmixdb Generate a mixture database
|
12
|
-
gentcst Generate target configuration from a subdirectory tree
|
13
|
-
keras_onnx Convert a trained Keras model to ONNX
|
14
|
-
keras_predict Run Keras predict on a trained model
|
15
|
-
keras_train Train a model using Keras
|
16
|
-
lsdb List information about a mixture database
|
17
|
-
mkmanifest Make ASR manifest JSON file
|
18
|
-
mkwav Make WAV files from a mixture database
|
19
|
-
onnx_predict Run ONNX predict on a trained model
|
20
|
-
plot Plot mixture data
|
21
|
-
post_spenh_targetf Run post-processing for speech enhancement targetf data
|
22
|
-
torchl_onnx Convert a trained Pytorch Lightning model to ONNX
|
23
|
-
torchl_predict Run Lightning predict on a trained model
|
24
|
-
torchl_train Train a model using Lightning
|
25
|
-
tplot Plot truth data
|
26
|
-
vars List custom SonusAI variables
|
6
|
+
<This information is automatically generated.>
|
27
7
|
|
28
8
|
Aaware Sound and Voice Machine Learning Framework. See 'sonusai help <command>'
|
29
9
|
for more information on a specific command.
|
30
10
|
|
31
11
|
"""
|
32
|
-
|
12
|
+
import signal
|
13
|
+
|
14
|
+
|
15
|
+
def signal_handler(_sig, _frame):
|
16
|
+
import sys
|
17
|
+
|
18
|
+
from sonusai import logger
|
19
|
+
|
20
|
+
logger.info('Canceled due to keyboard interrupt')
|
21
|
+
sys.exit(1)
|
22
|
+
|
23
|
+
|
24
|
+
signal.signal(signal.SIGINT, signal_handler)
|
33
25
|
|
34
26
|
|
35
27
|
def main() -> None:
|
28
|
+
from importlib import import_module
|
29
|
+
from pkgutil import iter_modules
|
30
|
+
|
31
|
+
from sonusai import commands_list
|
32
|
+
|
33
|
+
plugins = {}
|
34
|
+
plugin_docstrings = []
|
35
|
+
for _, name, _ in iter_modules():
|
36
|
+
if name.startswith('sonusai_') and not name.startswith('sonusai_asr_'):
|
37
|
+
module = import_module(name)
|
38
|
+
plugins[name] = {
|
39
|
+
'commands': commands_list(module.commands_doc),
|
40
|
+
'basedir': module.BASEDIR,
|
41
|
+
}
|
42
|
+
plugin_docstrings.append(module.commands_doc)
|
43
|
+
|
36
44
|
from docopt import docopt
|
37
45
|
|
38
|
-
import
|
46
|
+
from sonusai import __version__
|
47
|
+
from sonusai.utils import add_commands_to_docstring
|
39
48
|
from sonusai.utils import trim_docstring
|
40
49
|
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
'evaluate',
|
45
|
-
'genft',
|
46
|
-
'genmix',
|
47
|
-
'genmixdb',
|
48
|
-
'gentcst',
|
49
|
-
'keras_onnx',
|
50
|
-
'keras_predict',
|
51
|
-
'keras_train',
|
52
|
-
'lsdb',
|
53
|
-
'mkmanifest',
|
54
|
-
'mkwav',
|
55
|
-
'onnx_predict',
|
56
|
-
'plot',
|
57
|
-
'post_spenh_targetf',
|
58
|
-
'torchl_onnx',
|
59
|
-
'torchl_predict',
|
60
|
-
'torchl_train',
|
61
|
-
'tplot',
|
62
|
-
'vars',
|
63
|
-
)
|
64
|
-
|
65
|
-
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
50
|
+
args = docopt(trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
|
51
|
+
version=__version__,
|
52
|
+
options_first=True)
|
66
53
|
|
67
54
|
command = args['<command>']
|
68
55
|
argv = args['<args>']
|
69
56
|
|
57
|
+
import sys
|
58
|
+
from os.path import join
|
70
59
|
from subprocess import call
|
71
60
|
|
72
61
|
import sonusai
|
73
|
-
from sonusai import
|
62
|
+
from sonusai import logger
|
74
63
|
|
64
|
+
base_commands = sonusai.commands_list()
|
75
65
|
if command == 'help':
|
76
66
|
if not argv:
|
77
67
|
exit(call(['sonusai', '-h']))
|
78
|
-
elif argv[0] in
|
79
|
-
exit(call(['python', f'{sonusai.BASEDIR
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
68
|
+
elif argv[0] in base_commands:
|
69
|
+
exit(call(['python', f'{join(sonusai.BASEDIR, argv[0])}.py', '-h']))
|
70
|
+
|
71
|
+
for plugin, data in plugins.items():
|
72
|
+
if argv[0] in data['commands']:
|
73
|
+
exit(call(['python', f'{join(data["basedir"], argv[0])}.py', '-h']))
|
74
|
+
|
75
|
+
logger.error(f"{argv[0]} is not a SonusAI command. See 'sonusai help'.")
|
76
|
+
sys.exit(1)
|
77
|
+
|
78
|
+
if command in base_commands:
|
79
|
+
exit(call(['python', f'{join(sonusai.BASEDIR, command)}.py'] + argv))
|
80
|
+
|
81
|
+
for plugin, data in plugins.items():
|
82
|
+
if command in data['commands']:
|
83
|
+
exit(call(['python', f'{join(data["basedir"], command)}.py'] + argv))
|
84
84
|
|
85
|
-
|
85
|
+
logger.error(f"{command} is not a SonusAI command. See 'sonusai help'.")
|
86
|
+
sys.exit(1)
|
86
87
|
|
87
88
|
|
88
89
|
if __name__ == '__main__':
|
89
|
-
|
90
|
-
main()
|
91
|
-
except KeyboardInterrupt:
|
92
|
-
logger.info('Canceled due to keyboard interrupt')
|
93
|
-
raise SystemExit(0)
|
90
|
+
main()
|
sonusai/mixture/__init__.py
CHANGED
@@ -1,6 +1,4 @@
|
|
1
1
|
# SonusAI mixture utilities
|
2
|
-
from .audio import calculate_audio_from_transform
|
3
|
-
from .audio import calculate_transform_from_audio
|
4
2
|
from .audio import get_duration
|
5
3
|
from .audio import get_next_noise
|
6
4
|
from .audio import get_num_samples
|
@@ -83,6 +81,7 @@ from .datatypes import TruthFunctionConfig
|
|
83
81
|
from .datatypes import TruthSetting
|
84
82
|
from .datatypes import TruthSettings
|
85
83
|
from .datatypes import UniversalSNR
|
84
|
+
from .feature import get_audio_from_feature
|
86
85
|
from .feature import get_feature_from_audio
|
87
86
|
from .generation import generate_mixtures
|
88
87
|
from .generation import get_all_snrs_from_config
|
@@ -102,11 +101,14 @@ from .helpers import augmented_noise_samples
|
|
102
101
|
from .helpers import augmented_target_samples
|
103
102
|
from .helpers import check_audio_files_exist
|
104
103
|
from .helpers import forward_transform
|
104
|
+
from .helpers import get_audio_from_transform
|
105
105
|
from .helpers import get_ft
|
106
106
|
from .helpers import get_segsnr
|
107
|
+
from .helpers import get_transform_from_audio
|
107
108
|
from .helpers import get_truth_t
|
108
109
|
from .helpers import inverse_transform
|
109
110
|
from .helpers import mixture_metadata
|
111
|
+
from .helpers import read_mixture_data
|
110
112
|
from .helpers import write_mixture_data
|
111
113
|
from .helpers import write_mixture_metadata
|
112
114
|
from .log_duration_and_sizes import log_duration_and_sizes
|
sonusai/mixture/audio.py
CHANGED
@@ -1,11 +1,6 @@
|
|
1
1
|
from functools import lru_cache
|
2
2
|
|
3
|
-
from pyaaware import ForwardTransform
|
4
|
-
from pyaaware import InverseTransform
|
5
|
-
|
6
|
-
from sonusai.mixture.datatypes import AudioF
|
7
3
|
from sonusai.mixture.datatypes import AudioT
|
8
|
-
from sonusai.mixture.datatypes import EnergyT
|
9
4
|
from sonusai.mixture.datatypes import ImpulseResponseData
|
10
5
|
|
11
6
|
|
@@ -22,35 +17,6 @@ def get_next_noise(audio: AudioT, offset: int, length: int) -> AudioT:
|
|
22
17
|
return np.take(audio, range(offset, offset + length), mode='wrap')
|
23
18
|
|
24
19
|
|
25
|
-
def calculate_transform_from_audio(audio: AudioT,
|
26
|
-
transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
|
27
|
-
"""Apply forward transform to input audio data to generate transform data
|
28
|
-
|
29
|
-
:param audio: Time domain data [samples]
|
30
|
-
:param transform: ForwardTransform object
|
31
|
-
:return: Frequency domain data [frames, bins], Energy [frames]
|
32
|
-
"""
|
33
|
-
f, e = transform.execute_all(audio)
|
34
|
-
return f.transpose(), e
|
35
|
-
|
36
|
-
|
37
|
-
def calculate_audio_from_transform(data: AudioF,
|
38
|
-
transform: InverseTransform,
|
39
|
-
trim: bool = True) -> tuple[AudioT, EnergyT]:
|
40
|
-
"""Apply inverse transform to input transform data to generate audio data
|
41
|
-
|
42
|
-
:param data: Frequency domain data [frames, bins]
|
43
|
-
:param transform: InverseTransform object
|
44
|
-
:param trim: Removes starting samples so output waveform will be time-aligned with input waveform to the transform
|
45
|
-
:return: Time domain data [samples], Energy [frames]
|
46
|
-
"""
|
47
|
-
t, e = transform.execute_all(data.transpose())
|
48
|
-
if trim:
|
49
|
-
t = t[transform.N - transform.R:]
|
50
|
-
|
51
|
-
return t, e
|
52
|
-
|
53
|
-
|
54
20
|
def get_duration(audio: AudioT) -> float:
|
55
21
|
"""Get duration of audio in seconds
|
56
22
|
|