sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/__init__.py CHANGED
@@ -19,6 +19,7 @@ commands_doc = """
      onnx_predict                Run ONNX predict on a trained model
      plot                        Plot mixture data
      post_spenh_targetf          Run post-processing for speech enhancement targetf data
+     summarize_metric_spenh      Summarize speech enhancement and analysis results
      tplot                       Plot truth data
      vars                        List custom SonusAI variables
  """
sonusai/audiofe.py CHANGED
@@ -12,7 +12,7 @@ options:
      -m MODEL, --model MODEL        PL model .py file path.
      -k CKPT, --checkpoint CKPT     PL checkpoint file with weights.
      -a ASR, --asr ASR              ASR method to use.
-     -w WMODEL, --whisper WMODEL    Whisper model used in aixplain_whisper and whisper methods. [default: tiny].
+     -w WMODEL, --whisper WMODEL    Model used in whisper, aixplain_whisper and faster_whisper methods. [default: tiny].

  Aaware SonusAI Audio Front End.

@@ -24,47 +24,43 @@ audiofe_capture_<TIMESTAMP>.wav.
  If a model is specified, run prediction on audio data from this model. Then compute the inverse transform of the
  prediction result and save to audiofe_predict_<TIMESTAMP>.wav.

+ Also, if a model is specified, save plots of the capture data (time-domain signal and feature) to
+ audiofe_capture_<TIMESTAMP>.png and predict data (time-domain signal and feature) to
+ audiofe_predict_<TIMESTAMP>.png.
+
  If an ASR is specified, run ASR on the captured audio and print the results. In addition, if a model was also specified,
- run ASR on the predict audio and print the results.
+ run ASR on the predict audio and print the results. Examples: faster_whisper, google.

  If the debug option is enabled, write capture audio, feature, reconstruct audio, predict, and predict audio to
  audiofe_<TIMESTAMP>.h5.

  """
- from os.path import exists
- from select import select
- from sys import stdin
+ import signal

- import h5py
  import numpy as np
- import pyaudio
- import torch
- from docopt import docopt
- from docopt import printable_usage
-
- import sonusai
- from sonusai import create_file_handler
- from sonusai import initial_log_messages
- from sonusai import logger
- from sonusai import update_console_handler
+
  from sonusai.mixture import AudioT
- from sonusai.mixture import CHANNEL_COUNT
- from sonusai.mixture import SAMPLE_RATE
- from sonusai.mixture import get_audio_from_feature
- from sonusai.mixture import get_feature_from_audio
- from sonusai.mixture import read_audio
- from sonusai.utils import calc_asr
- from sonusai.utils import create_timestamp
- from sonusai.utils import get_input_device_index_by_name
- from sonusai.utils import get_input_devices
- from sonusai.utils import load_torchl_ckpt_model
- from sonusai.utils import trim_docstring
- from sonusai.utils import write_wav
+
+
+ def signal_handler(_sig, _frame):
+     import sys
+
+     from sonusai import logger
+
+     logger.info('Canceled due to keyboard interrupt')
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, signal_handler)


  def main() -> None:
+     from docopt import docopt
+
+     import sonusai
+     from sonusai.utils import trim_docstring
+
      args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
-     ts = create_timestamp()

      verbose = args['--verbose']
      length = float(args['--length'])
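For reference, the ASR path described in the docstring above reduces to a single calc_asr call, exactly as used later in this file. A minimal sketch, assuming a captured or file-read audio array; the wav path and the 'faster_whisper'/'tiny' choices are illustrative, not values fixed by the package:

    from sonusai.mixture import read_audio
    from sonusai.utils import calc_asr

    capture_audio = read_audio('capture.wav')  # illustrative path
    text = calc_asr(capture_audio, engine='faster_whisper', whisper_model_name='tiny').text
    print(text)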
@@ -76,13 +72,63 @@ def main() -> None:
      debug = args['--debug']
      show = args['--show']

-     capture_name = f'audiofe_capture_{ts}.wav'
-     predict_name = f'audiofe_predict_{ts}.wav'
+     from os.path import exists
+
+     import h5py
+     import pyaudio
+     import torch
+     from docopt import printable_usage
+     from sonusai_torchl.utils import load_torchl_ckpt_model
+     from sonusai.utils.onnx_utils import load_ort_session
+
+     from sonusai import create_file_handler
+     from sonusai import initial_log_messages
+     from sonusai import logger
+     from sonusai import update_console_handler
+     from sonusai.mixture import SAMPLE_RATE
+     from sonusai.mixture import get_audio_from_feature
+     from sonusai.mixture import get_feature_from_audio
+     from sonusai.utils import calc_asr
+     from sonusai.utils import create_timestamp
+     from sonusai.utils import get_input_devices
+     from sonusai.utils import trim_docstring
+     from sonusai.utils import write_wav
+
+     ts = create_timestamp()
+     capture_name = f'audiofe_capture_{ts}'
+     capture_wav = capture_name + '.wav'
+     capture_png = capture_name + '.png'
+     predict_name = f'audiofe_predict_{ts}'
+     predict_wav = predict_name + '.wav'
+     predict_png = predict_name + '.png'
      h5_name = f'audiofe_{ts}.h5'

-     if model_name is not None and ckpt_name is None:
-         print(printable_usage(trim_docstring(__doc__)))
-         exit(1)
+     if model_name is not None:
+         from os.path import splitext
+         if splitext(model_name)[1] == '.onnx':
+             session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_name)
+             if hparams is None:
+                 logger.error(f'Error: onnx model does not have required SonusAI hyper-parameters, can not proceed.')
+                 raise SystemExit(1)
+             feature_mode = hparams["feature"]
+             model_is_onnx = True
+             in0name = sess_inputs[0].name
+             in0type = sess_inputs[0].type
+             out0name = sess_outputs[0].name
+             out_names = [n.name for n in session.get_outputs()]
+             if in0type.find('float16') != -1:
+                 model_is_fp16 = True
+                 logger.info(f'Detected input of float16, converting all feature inputs to that type.')
+             else:
+                 model_is_fp16 = False
+         else:
+             model_is_onnx = False
+             if ckpt_name is None:
+                 print(printable_usage(trim_docstring(__doc__)))
+                 exit(1)
+             model = load_torchl_ckpt_model(model_name=model_name, ckpt_name=ckpt_name)
+             feature_mode = model.hparams.feature
+             model.eval()

      # Setup logging file
      create_file_handler('audiofe.log')
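The '.onnx' branch above gets its session, hyper-parameters, and input/output metadata from load_ort_session. For readers without that helper, a minimal sketch of the same input/output inspection and float16 check using only the public onnxruntime API (the model path is illustrative, and this sketch does not cover the SonusAI hyper-parameters that load_ort_session also returns):

    import onnxruntime as ort

    session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])  # illustrative path
    in0 = session.get_inputs()[0]
    out_names = [o.name for o in session.get_outputs()]
    model_is_fp16 = 'float16' in in0.type  # in0.type looks like 'tensor(float16)' or 'tensor(float)'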
@@ -107,26 +153,28 @@ def main() -> None:
          except ValueError as e:
              logger.exception(e)
              return
+         # Only write if capture, not for file input
+         write_wav(capture_wav, capture_audio, SAMPLE_RATE)
+         logger.info('')
+         logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_wav}')

-     write_wav(capture_name, capture_audio, SAMPLE_RATE)
-     logger.info('')
-     logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_name}')
      if debug:
          with h5py.File(h5_name, 'a') as f:
              if 'capture_audio' in f:
                  del f['capture_audio']
              f.create_dataset('capture_audio', data=capture_audio)
-         logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {h5_name}')
+         logger.info(f'Wrote capture feature data with shape {capture_audio.shape} to {h5_name}')

      if asr_name is not None:
+         logger.info(f'Running ASR on captured audio with {asr_name} ...')
          capture_asr = calc_asr(capture_audio, engine=asr_name, whisper_model_name=whisper_name).text
          logger.info(f'Capture audio ASR: {capture_asr}')

      if model_name is not None:
-         model = load_torchl_ckpt_model(model_name=model_name, ckpt_name=ckpt_name)
-         model.eval()
+         feature = get_feature_from_audio(audio=capture_audio, feature_mode=feature_mode)  # frames x stride x feat_params
+         save_figure(capture_png, capture_audio, feature)
+         logger.info(f'Wrote capture plots to {capture_png}')

-         feature = get_feature_from_audio(audio=capture_audio, feature_mode=model.hparams.feature)
          if debug:
              with h5py.File(h5_name, 'a') as f:
                  if 'feature' in f:
@@ -134,23 +182,20 @@ def main() -> None:
                  f.create_dataset('feature', data=feature)
              logger.info(f'Wrote feature with shape {feature.shape} to {h5_name}')

-         # if debug:
-         # reconstruct_name = f'audiofe_reconstruct_{ts}.wav'
-         # reconstruct_audio = get_audio_from_feature(feature=feature, feature_mode=model.hparams.feature)
-         # samples = min(len(capture_audio), len(reconstruct_audio))
-         # max_err = np.max(np.abs(capture_audio[:samples] - reconstruct_audio[:samples]))
-         # logger.info(f'Maximum error between capture and reconstruct: {max_err}')
-         # write_wav(reconstruct_name, reconstruct_audio, SAMPLE_RATE)
-         # logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {reconstruct_name}')
-         # with h5py.File(h5_name, 'a') as f:
-         # if 'reconstruct_audio' in f:
-         # del f['reconstruct_audio']
-         # f.create_dataset('reconstruct_audio', data=reconstruct_audio)
-         # logger.info(f'Wrote reconstruct audio with shape {reconstruct_audio.shape} to {h5_name}')
-
-         with torch.no_grad():
-             # model wants batch x timesteps x feature_parameters
-             predict = model(torch.tensor(feature).permute((1, 0, 2))).permute(1, 0, 2).numpy()
+         if model_is_onnx:
+             # run ort session, wants i.e. batch x tsteps x feat_params, outputs numpy BxTxFP or BxFP
+             # Note: a full reshape is not needed here since we assume a speech enhancement type model, so a transpose suffices
+             if model_is_fp16:
+                 feature = np.float16(feature)
+             # run inference, ort session wants i.e. batch x tsteps x feat_params, outputs numpy BxTxFP or BxFP
+             predict = np.transpose(session.run(out_names, {in0name: np.transpose(feature, (1, 0, 2))})[0], (1, 0, 2))
+         else:
+             with torch.no_grad():
+                 # model wants batch x timesteps x feature_parameters
+                 predict = model(torch.tensor(feature).permute((1, 0, 2))).permute(1, 0, 2).numpy()
+
+
+
      if debug:
          with h5py.File(h5_name, 'a') as f:
              if 'predict' in f:
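The transposes in the hunk above only swap the first two axes: get_feature_from_audio returns frames x stride x feat_params, while both the ORT session and the torch model expect batch x timesteps x feature_parameters. A shape-only sketch with illustrative sizes (the model call is replaced by an identity so the sketch stays self-contained):

    import numpy as np

    feature = np.zeros((310, 1, 64), dtype=np.float32)  # frames x stride x feat_params (illustrative sizes)
    model_in = np.transpose(feature, (1, 0, 2))          # 1 x 310 x 64, i.e. batch x timesteps x feat_params
    model_out = model_in                                 # placeholder for session.run(...) or model(...)
    predict = np.transpose(model_out, (1, 0, 2))         # back to 310 x 1 x 64 for get_audio_from_feature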
@@ -158,9 +203,9 @@ def main() -> None:
                  f.create_dataset('predict', data=predict)
              logger.info(f'Wrote predict with shape {predict.shape} to {h5_name}')

-         predict_audio = get_audio_from_feature(feature=predict, feature_mode=model.hparams.feature)
-         write_wav(predict_name, predict_audio, SAMPLE_RATE)
-         logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_name}')
+         predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
+         write_wav(predict_wav, predict_audio, SAMPLE_RATE)
+         logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_wav}')
          if debug:
              with h5py.File(h5_name, 'a') as f:
                  if 'predict_audio' in f:
@@ -168,12 +213,27 @@ def main() -> None:
                  f.create_dataset('predict_audio', data=predict_audio)
              logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {h5_name}')

+         save_figure(predict_png, predict_audio, predict)
+         logger.info(f'Wrote predict plots to {predict_png}')
+
          if asr_name is not None:
+             logger.info(f'Running ASR on model-enhanced audio with {asr_name} ...')
              predict_asr = calc_asr(predict_audio, engine=asr_name, whisper_model_name=whisper_name).text
              logger.info(f'Predict audio ASR: {predict_asr}')


  def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1024) -> AudioT:
+     from select import select
+     from sys import stdin
+
+     import pyaudio
+
+     from sonusai import logger
+     from sonusai.mixture import CHANNEL_COUNT
+     from sonusai.mixture import SAMPLE_RATE
+     from sonusai.utils import get_input_device_index_by_name
+     from sonusai.utils import get_input_devices
+
      p = pyaudio.PyAudio()

      input_devices = get_input_devices(p)
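get_input_devices and get_input_device_index_by_name wrap PyAudio device discovery; a rough, purely illustrative equivalent using the PyAudio API directly (this is not the helpers' actual implementation):

    import pyaudio

    p = pyaudio.PyAudio()
    for i in range(p.get_device_count()):
        info = p.get_device_info_by_index(i)
        if int(info.get('maxInputChannels', 0)) > 0:  # keep only capture-capable devices
            print(i, info['name'])
    p.terminate()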
@@ -224,6 +284,10 @@ def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1024) -> AudioT:


  def get_frames_from_file(input_name: str, length: float) -> AudioT:
+     from sonusai import logger
+     from sonusai.mixture import SAMPLE_RATE
+     from sonusai.mixture import read_audio
+
      logger.info(f'Capturing from {input_name}')
      frames = read_audio(input_name)
      if length != -1:
@@ -233,5 +297,37 @@ def get_frames_from_file(input_name: str, length: float) -> AudioT:
      return frames


+ def save_figure(name: str, audio: np.ndarray, feature: np.ndarray) -> None:
+     import matplotlib.pyplot as plt
+     from scipy.interpolate import CubicSpline
+
+     from sonusai.mixture import SAMPLE_RATE
+     from sonusai.utils import unstack_complex
+
+     spectrum = 20 * np.log(np.abs(np.squeeze(unstack_complex(feature)).transpose()))
+     frames = spectrum.shape[1]
+     samples = (len(audio) // frames) * frames
+     length_in_s = samples / SAMPLE_RATE
+     interp = samples // frames
+
+     ts = np.arange(0.0, length_in_s, interp / SAMPLE_RATE)
+     t = np.arange(0.0, length_in_s, 1 / SAMPLE_RATE)
+
+     spectrum = CubicSpline(ts, spectrum, axis=-1)(t)
+
+     fig, (ax1, ax2) = plt.subplots(nrows=2)
+     ax1.set_title(name)
+     ax1.plot(t, audio[:samples])
+     ax1.set_ylabel('Signal')
+     ax1.set_xlim(0, length_in_s)
+     ax1.set_ylim(-1, 1)
+
+     ax2.imshow(spectrum, origin='lower', aspect='auto')
+     ax2.set_xticks([])
+     ax2.set_ylabel('Feature')
+
+     plt.savefig(name, dpi=300)
+
+
  if __name__ == '__main__':
      main()
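save_figure stretches the per-frame feature spectrogram to one value per audio sample with a cubic spline so it lines up with the waveform panel. A standalone sketch of just that interpolation step, with illustrative sizes (the 16 kHz rate and 160 samples per frame are assumptions for the sketch, not values dictated by the feature mode):

    import numpy as np
    from scipy.interpolate import CubicSpline

    sample_rate = 16000                                 # assumed rate for the sketch
    frames, samples_per_frame = 100, 160                # illustrative: 1 s of audio, 10 ms frames
    samples = frames * samples_per_frame
    spectrum = np.random.rand(64, frames)               # feat_params x frames, as in save_figure

    ts = np.arange(0.0, samples / sample_rate, samples_per_frame / sample_rate)  # one time per frame
    t = np.arange(0.0, samples / sample_rate, 1 / sample_rate)                   # one time per sample
    spectrum_per_sample = CubicSpline(ts, spectrum, axis=-1)(t)                  # now 64 x samples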