sonusai 0.16.1__py3-none-any.whl → 0.17.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
sonusai/audiofe.py CHANGED
@@ -1,18 +1,17 @@
 """sonusai audiofe
 
-usage: audiofe [-hvds] [--version] [-i INPUT] [-l LENGTH] [-m MODEL] [-k CKPT] [-a ASR] [-w WMODEL]
+usage: audiofe [-hvds] [--version] [-i INPUT] [-l LENGTH] [-m MODEL] [-a ASR] [-w WMODEL]
 
 options:
     -h, --help
     -v, --verbose                   Be verbose.
     -d, --debug                     Write debug data to H5 file.
-    -s, --show                      Show a list of available audio inputs.
+    -s, --show                      Display a list of available audio inputs.
     -i INPUT, --input INPUT         Input audio.
     -l LENGTH, --length LENGTH      Length of audio in seconds. [default: -1].
-    -m MODEL, --model MODEL         PL model .py file path.
-    -k CKPT, --checkpoint CKPT      PL checkpoint file with weights.
+    -m MODEL, --model MODEL         ONNX model.
     -a ASR, --asr ASR               ASR method to use.
-    -w WMODEL, --whisper WMODEL     Whisper model used in aixplain_whisper and whisper methods. [default: tiny].
+    -w WMODEL, --whisper WMODEL     Model used in whisper, aixplain_whisper and faster_whisper methods. [default: tiny].
 
 Aaware SonusAI Audio Front End.
 
@@ -29,7 +28,7 @@ audiofe_capture_<TIMESTAMP>.png and predict data (time-domain signal and feature
 audiofe_predict_<TIMESTAMP>.png.
 
 If an ASR is specified, run ASR on the captured audio and print the results. In addition, if a model was also specified,
-run ASR on the predict audio and print the results.
+run ASR on the predict audio and print the results. Examples: faster_whisper, google,
 
 If the debug option is enabled, write capture audio, feature, reconstruct audio, predict, and predict audio to
 audiofe_<TIMESTAMP>.h5.
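
As a usage note, a typical invocation of the new 0.17.x interface might look like this (file names are illustrative):

    audiofe -i capture.wav -l 10 -m enhance_model.onnx -a faster_whisper -w tiny

This reads (or captures) audio, runs the ONNX model on the extracted feature, and prints ASR results for both the captured and the predicted audio.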
@@ -66,7 +65,6 @@ def main() -> None:
     length = float(args['--length'])
     input_name = args['--input']
     model_name = args['--model']
-    ckpt_name = args['--checkpoint']
     asr_name = args['--asr']
     whisper_name = args['--whisper']
     debug = args['--debug']
@@ -76,9 +74,6 @@ def main() -> None:
 
     import h5py
     import pyaudio
-    import torch
-    from docopt import printable_usage
-    from sonusai_torchl.utils import load_torchl_ckpt_model
 
     from sonusai import create_file_handler
     from sonusai import initial_log_messages
@@ -90,7 +85,7 @@ def main() -> None:
     from sonusai.utils import calc_asr
     from sonusai.utils import create_timestamp
     from sonusai.utils import get_input_devices
-    from sonusai.utils import trim_docstring
+    from sonusai.utils import load_ort_session
     from sonusai.utils import write_wav
 
     ts = create_timestamp()
@@ -102,10 +97,6 @@ def main() -> None:
     predict_png = predict_name + '.png'
     h5_name = f'audiofe_{ts}.h5'
 
-    if model_name is not None and ckpt_name is None:
-        print(printable_usage(trim_docstring(__doc__)))
-        exit(1)
-
     # Setup logging file
     create_file_handler('audiofe.log')
     update_console_handler(verbose)
@@ -129,26 +120,35 @@ def main() -> None:
         except ValueError as e:
             logger.exception(e)
             return
+        # Only write if capture from device, not for file input
+        write_wav(capture_wav, capture_audio, SAMPLE_RATE)
+        logger.info('')
+        logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_wav}')
 
-    write_wav(capture_wav, capture_audio, SAMPLE_RATE)
-    logger.info('')
-    logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_wav}')
     if debug:
         with h5py.File(h5_name, 'a') as f:
             if 'capture_audio' in f:
                 del f['capture_audio']
             f.create_dataset('capture_audio', data=capture_audio)
-        logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {h5_name}')
+        logger.info(f'Wrote capture feature data with shape {capture_audio.shape} to {h5_name}')
 
     if asr_name is not None:
+        logger.info(f'Running ASR on captured audio with {asr_name} ...')
         capture_asr = calc_asr(capture_audio, engine=asr_name, whisper_model_name=whisper_name).text
         logger.info(f'Capture audio ASR: {capture_asr}')
 
     if model_name is not None:
-        model = load_torchl_ckpt_model(model_name=model_name, ckpt_name=ckpt_name)
-        model.eval()
-
-        feature = get_feature_from_audio(audio=capture_audio, feature_mode=model.hparams.feature)
+        session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_name)
+        if hparams is None:
+            logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
+            raise SystemExit(1)
+        feature_mode = hparams.feature
+        in0name = sess_inputs[0].name
+        in0type = sess_inputs[0].type
+        out_names = [n.name for n in session.get_outputs()]
+
+        # frames x stride x feat_params
+        feature = get_feature_from_audio(audio=capture_audio, feature_mode=feature_mode)
         save_figure(capture_png, capture_audio, feature)
         logger.info(f'Wrote capture plots to {capture_png}')
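
The unpacking above documents the contract of the new sonusai.utils.load_ort_session helper: it returns the onnxruntime session and session options, the model root, the SonusAI hyperparameters embedded in the ONNX file (None when absent), and the session input/output lists. A minimal sketch of such a loader in plain onnxruntime follows; the 'hparams' metadata key and the literal_eval decoding are assumptions about how SonusAI embeds hyperparameters, not something this diff confirms.

    # Minimal sketch of an ONNX session loader in the spirit of load_ort_session.
    # Assumption: hyperparameters are stored as a Python literal under the
    # 'hparams' key of the ONNX custom metadata; the real helper may differ.
    from ast import literal_eval

    import onnxruntime as ort


    def load_session(model_path: str):
        session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])

        meta = session.get_modelmeta().custom_metadata_map
        hparams = literal_eval(meta['hparams']) if 'hparams' in meta else None

        # NodeArg objects expose .name, .type (e.g. 'tensor(float16)'), and .shape
        sess_inputs = session.get_inputs()
        sess_outputs = session.get_outputs()
        return session, hparams, sess_inputs, sess_outputs

The hparams-is-None check in the diff follows from this contract: a model exported without SonusAI metadata cannot supply feature_mode, so audiofe exits instead of guessing.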
 
@@ -159,9 +159,14 @@ def main() -> None:
                 f.create_dataset('feature', data=feature)
             logger.info(f'Wrote feature with shape {feature.shape} to {h5_name}')
 
-        with torch.no_grad():
-            # model wants batch x timesteps x feature_parameters
-            predict = model(torch.tensor(feature).permute((1, 0, 2))).permute(1, 0, 2).numpy()
+        if in0type.find('float16') != -1:
+            logger.info(f'Detected input of float16, converting all feature inputs to that type.')
+            feature = np.float16(feature)  # type: ignore
+
+        # Run inference, ort session wants batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
+        # Note full reshape not needed here since we assume speech enhancement type model, so a transpose suffices
+        predict = np.transpose(session.run(out_names, {in0name: np.transpose(feature, (1, 0, 2))})[0], (1, 0, 2))
+
         if debug:
             with h5py.File(h5_name, 'a') as f:
                 if 'predict' in f:
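
The inference line above moves data rather than reshaping it: get_feature_from_audio returns frames x stride x feat_params, the session expects batch x timesteps x feat_params, and for this enhancement-style model swapping the first two axes on the way in (and back on the way out) is sufficient. A self-contained sketch with an identity stand-in for session.run and illustrative shapes:

    # Sketch of the transpose-in / transpose-out pattern used above.
    # Shapes are illustrative; run_model stands in for session.run.
    import numpy as np

    frames, stride, feat_params = 100, 7, 40
    feature = np.zeros((frames, stride, feat_params), dtype=np.float32)

    # ONNX Runtime reports input types as strings such as 'tensor(float)' or
    # 'tensor(float16)', hence the substring check before down-casting.
    in0type = 'tensor(float16)'
    if in0type.find('float16') != -1:
        feature = feature.astype(np.float16)


    def run_model(x: np.ndarray) -> np.ndarray:
        # placeholder for: session.run(out_names, {in0name: x})[0]
        return x


    # frames x stride x feat_params -> stride x frames x feat_params and back
    predict = np.transpose(run_model(np.transpose(feature, (1, 0, 2))), (1, 0, 2))
    assert predict.shape == (frames, stride, feat_params)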
@@ -169,7 +174,7 @@ def main() -> None:
                 f.create_dataset('predict', data=predict)
             logger.info(f'Wrote predict with shape {predict.shape} to {h5_name}')
 
-        predict_audio = get_audio_from_feature(feature=predict, feature_mode=model.hparams.feature)
+        predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
         write_wav(predict_wav, predict_audio, SAMPLE_RATE)
         logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_wav}')
         if debug:
@@ -183,6 +188,7 @@ def main() -> None:
         logger.info(f'Wrote predict plots to {predict_png}')
 
     if asr_name is not None:
+        logger.info(f'Running ASR on model-enhanced audio with {asr_name} ...')
         predict_asr = calc_asr(predict_audio, engine=asr_name, whisper_model_name=whisper_name).text
         logger.info(f'Predict audio ASR: {predict_asr}')
 
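
Taken together, the new predict path reduces to roughly the following sketch. File names are illustrative, the float16 branch is omitted for brevity, and import locations not visible in this diff (read_audio, SAMPLE_RATE, the feature helpers) are assumptions based on the rest of SonusAI; the remaining calls mirror the diff.

    # End-to-end sketch of the ONNX-based predict path in audiofe 0.17.x.
    import numpy as np

    from sonusai.mixture import SAMPLE_RATE              # assumed location
    from sonusai.mixture import get_audio_from_feature   # assumed location
    from sonusai.mixture import get_feature_from_audio   # assumed location
    from sonusai.mixture import read_audio               # assumed location
    from sonusai.utils import calc_asr
    from sonusai.utils import load_ort_session
    from sonusai.utils import write_wav

    capture_audio = read_audio('capture.wav')  # illustrative input file

    session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session('model.onnx')
    feature = get_feature_from_audio(audio=capture_audio, feature_mode=hparams.feature)

    out_names = [n.name for n in session.get_outputs()]
    predict = np.transpose(
        session.run(out_names, {sess_inputs[0].name: np.transpose(feature, (1, 0, 2))})[0],
        (1, 0, 2))

    predict_audio = get_audio_from_feature(feature=predict, feature_mode=hparams.feature)
    write_wav('predict.wav', predict_audio, SAMPLE_RATE)
    print(calc_asr(predict_audio, engine='faster_whisper', whisper_model_name='tiny').text)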