sonusai 0.17.0__py3-none-any.whl → 0.17.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/audiofe.py +22 -51
- sonusai/calc_metric_spenh.py +206 -213
- sonusai/doc/doc.py +1 -1
- sonusai/mixture/__init__.py +2 -0
- sonusai/mixture/audio.py +12 -0
- sonusai/mixture/datatypes.py +11 -3
- sonusai/mixture/mixdb.py +101 -0
- sonusai/mixture/soundfile_audio.py +39 -0
- sonusai/mixture/speaker_metadata.py +35 -0
- sonusai/mixture/torchaudio_audio.py +22 -0
- sonusai/mkmanifest.py +1 -1
- sonusai/onnx_predict.py +114 -410
- sonusai/queries/queries.py +1 -1
- sonusai/speech/__init__.py +3 -0
- sonusai/speech/l2arctic.py +116 -0
- sonusai/speech/librispeech.py +99 -0
- sonusai/speech/mcgill.py +70 -0
- sonusai/speech/textgrid.py +100 -0
- sonusai/speech/timit.py +135 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +52 -0
- sonusai/speech/voxceleb2.py +86 -0
- sonusai/utils/__init__.py +2 -1
- sonusai/utils/asr_manifest_functions/__init__.py +0 -1
- sonusai/utils/asr_manifest_functions/data.py +0 -8
- sonusai/utils/asr_manifest_functions/librispeech.py +1 -1
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +1 -1
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +1 -1
- sonusai/utils/braced_glob.py +7 -3
- sonusai/utils/onnx_utils.py +110 -106
- sonusai/utils/path_info.py +7 -0
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/METADATA +2 -1
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/RECORD +35 -30
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/WHEEL +1 -1
- sonusai/calc_metric_spenh-save.py +0 -1334
- sonusai/onnx_predict-old.py +0 -240
- sonusai/onnx_predict-save.py +0 -487
- sonusai/ovino_predict.py +0 -508
- sonusai/ovino_query_devices.py +0 -47
- sonusai/torchl_onnx-old.py +0 -216
- {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py
CHANGED
@@ -5,14 +5,12 @@ usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
|
|
5
5
|
options:
|
6
6
|
-h, --help
|
7
7
|
-v, --verbose Be verbose.
|
8
|
-
-l, --list-device-details List details of all OpenVINO available devices
|
9
8
|
-i MIXID, --mixid MIXID Mixture ID(s) to generate if input is a mixture database. [default: *].
|
10
9
|
--include GLOB Search only files whose base name matches GLOB. [default: *.{wav,flac}].
|
11
10
|
-w, --write-wav Calculate inverse transform of prediction and write .wav files
|
12
|
-
-r, --reset Reset model between each file.
|
13
11
|
|
14
|
-
Run prediction (inference) using an
|
15
|
-
The
|
12
|
+
Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or audio files from a glob path.
|
13
|
+
The ONNX Runtime (ort) inference engine is used to execute the inference.
|
16
14
|
|
17
15
|
Inputs:
|
18
16
|
MODEL ONNX model .onnx file of a trained model (weights are expected to be in the file).
|
@@ -33,67 +31,34 @@ Note there are multiple ways to process model prediction over multiple audio dat
|
|
33
31
|
zero-padded to the size of the largest mixture.
|
34
32
|
2. TME (timestep multi-extension): mixture is split into multiple timesteps, i.e. batch[0] is starting timesteps, ...
|
35
33
|
Note that batches are run independently, thus sequential state from one set of timesteps to the next will not be
|
36
|
-
maintained, thus results for such models (i.e. conv, LSTMs, in the
|
34
|
+
maintained, thus results for such models (i.e. conv, LSTMs, in the timestep dimension) would not match using
|
35
|
+
TSE mode.
|
37
36
|
|
38
37
|
TBD not sure below make sense, need to continue ??
|
39
38
|
2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This make sense only if
|
40
|
-
independent predictions are made on each frame w/o considering previous frames (
|
41
|
-
|
42
|
-
3.
|
39
|
+
independent predictions are made on each frame w/o considering previous frames (timesteps=1) or there is no
|
40
|
+
timestep dimension in the model (timesteps=0).
|
41
|
+
3. Classification
|
43
42
|
|
44
|
-
Outputs the following to
|
43
|
+
Outputs the following to opredict-<TIMESTAMP> directory:
|
45
44
|
<id>.h5
|
46
45
|
dataset: predict
|
47
46
|
onnx_predict.log
|
48
47
|
|
49
48
|
"""
|
49
|
+
import signal
|
50
50
|
|
51
|
-
from sonusai import logger
|
52
|
-
from typing import Any, List, Optional, Tuple
|
53
|
-
from os.path import basename, splitext, exists, isfile
|
54
|
-
import onnxruntime as ort
|
55
|
-
import onnx
|
56
|
-
from onnx import ValueInfoProto
|
57
|
-
|
58
|
-
|
59
|
-
def load_ort_session(model_path, providers=['CPUExecutionProvider']):
|
60
|
-
if exists(model_path) and isfile(model_path):
|
61
|
-
model_basename = basename(model_path)
|
62
|
-
model_root = splitext(model_basename)[0]
|
63
|
-
logger.info(f'Importing model from {model_basename}')
|
64
|
-
try:
|
65
|
-
session = ort.InferenceSession(model_path, providers=providers)
|
66
|
-
options = ort.SessionOptions()
|
67
|
-
except Exception as e:
|
68
|
-
logger.exception(f'Error: could not load onnx model from {model_path}: {e}')
|
69
|
-
raise SystemExit(1)
|
70
|
-
else:
|
71
|
-
logger.exception(f'Error: model file does not exist: {model_path}')
|
72
|
-
raise SystemExit(1)
|
73
51
|
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
logger.info(f'Sonusai hyper-parameter metadata was found in model with {len(hparams)} parameters, '
|
79
|
-
f'checking for required ones ...')
|
80
|
-
# Print to log here will fail if required parameters not available.
|
81
|
-
logger.info(f'feature {hparams["feature"]}')
|
82
|
-
logger.info(f'batch_size {hparams["batch_size"]}')
|
83
|
-
logger.info(f'timesteps {hparams["timesteps"]}')
|
84
|
-
logger.info(f'flatten, add1ch {hparams["flatten"]}, {hparams["add1ch"]}')
|
85
|
-
logger.info(f'truth_mutex {hparams["truth_mutex"]}')
|
86
|
-
except:
|
87
|
-
hparams = None
|
88
|
-
logger.warning(f'Warning: onnx model does not have required SonusAI hyper-parameters.')
|
52
|
+
def signal_handler(_sig, _frame):
|
53
|
+
import sys
|
54
|
+
|
55
|
+
from sonusai import logger
|
89
56
|
|
90
|
-
|
91
|
-
|
57
|
+
logger.info('Canceled due to keyboard interrupt')
|
58
|
+
sys.exit(1)
|
92
59
|
|
93
|
-
#in_names = [n.name for n in session.get_inputs()]
|
94
|
-
#out_names = [n.name for n in session.get_outputs()]
|
95
60
|
|
96
|
-
|
61
|
+
signal.signal(signal.SIGINT, signal_handler)
|
97
62
|
|
98
63
|
|
99
64
|
def main() -> None:
|
@@ -105,400 +70,139 @@ def main() -> None:
|
|
105
70
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
106
71
|
|
107
72
|
verbose = args['--verbose']
|
108
|
-
|
109
|
-
writewav= args['--write-wav']
|
73
|
+
wav = args['--write-wav']
|
110
74
|
mixids = args['--mixid']
|
111
|
-
reset = args['--reset']
|
112
75
|
include = args['--include']
|
113
76
|
model_path = args['MODEL']
|
114
|
-
|
77
|
+
data_paths = args['DATA']
|
78
|
+
|
79
|
+
from os import makedirs
|
80
|
+
from os.path import abspath
|
81
|
+
from os.path import basename
|
82
|
+
from os.path import isdir
|
83
|
+
from os.path import join
|
84
|
+
from os.path import normpath
|
85
|
+
from os.path import realpath
|
86
|
+
from os.path import splitext
|
87
|
+
|
88
|
+
import h5py
|
89
|
+
import numpy as np
|
90
|
+
import onnxruntime as ort
|
91
|
+
|
92
|
+
from sonusai import create_file_handler
|
93
|
+
from sonusai import initial_log_messages
|
94
|
+
from sonusai import logger
|
95
|
+
from sonusai import update_console_handler
|
96
|
+
from sonusai.mixture import MixtureDatabase
|
97
|
+
from sonusai.mixture import get_audio_from_feature
|
98
|
+
from sonusai.utils import PathInfo
|
99
|
+
from sonusai.utils import braced_iglob
|
100
|
+
from sonusai.utils import create_ts_name
|
101
|
+
from sonusai.utils import load_ort_session
|
102
|
+
from sonusai.utils import reshape_inputs
|
103
|
+
from sonusai.utils import write_wav
|
104
|
+
|
105
|
+
mixdb_path = None
|
106
|
+
mixdb = None
|
107
|
+
p_mixids = None
|
108
|
+
entries: list[PathInfo] = []
|
109
|
+
|
110
|
+
if len(data_paths) == 1 and isdir(data_paths[0]):
|
111
|
+
# Assume it's a single path to SonusAI mixdb subdir
|
112
|
+
in_basename = basename(normpath(data_paths[0]))
|
113
|
+
mixdb_path = data_paths[0]
|
114
|
+
else:
|
115
|
+
# search all data paths for .wav, .flac (or whatever is specified in include)
|
116
|
+
in_basename = ''
|
117
|
+
|
118
|
+
output_dir = create_ts_name('opredict-' + in_basename)
|
119
|
+
makedirs(output_dir, exist_ok=True)
|
120
|
+
|
121
|
+
# Setup logging file
|
122
|
+
create_file_handler(join(output_dir, 'onnx-predict.log'))
|
123
|
+
update_console_handler(verbose)
|
124
|
+
initial_log_messages('onnx_predict')
|
115
125
|
|
116
126
|
providers = ort.get_available_providers()
|
117
|
-
logger.info(f'Loaded
|
127
|
+
logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
|
118
128
|
|
119
129
|
session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
|
120
130
|
if hparams is None:
|
121
|
-
logger.error(f'Error:
|
131
|
+
logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
|
122
132
|
raise SystemExit(1)
|
123
133
|
if len(sess_inputs) != 1:
|
124
|
-
logger.error(f'Error:
|
134
|
+
logger.error(f'Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
|
125
135
|
|
126
136
|
in0name = sess_inputs[0].name
|
127
137
|
in0type = sess_inputs[0].type
|
128
|
-
out0name = sess_outputs[0].name
|
129
138
|
out_names = [n.name for n in session.get_outputs()]
|
130
139
|
|
131
|
-
|
132
|
-
from sonusai.utils.asr_manifest_functions import PathInfo
|
133
|
-
from sonusai.utils import braced_iglob
|
134
|
-
from sonusai.mixture import MixtureDatabase
|
140
|
+
logger.info(f'Read and compiled ONNX model from {model_path}.')
|
135
141
|
|
136
|
-
mixdb_path
|
137
|
-
|
138
|
-
if len(datapaths) == 1 and isdir(datapaths[0]): # Assume it's a single path to sonusai mixdb subdir
|
139
|
-
in_basename = basename(normpath(datapaths[0]))
|
140
|
-
mixdb_path= datapaths[0]
|
142
|
+
if mixdb_path is not None:
|
143
|
+
# Assume it's a single path to SonusAI mixdb subdir
|
141
144
|
logger.debug(f'Attempting to load mixture database from {mixdb_path}')
|
142
145
|
mixdb = MixtureDatabase(mixdb_path)
|
143
|
-
logger.
|
146
|
+
logger.info(f'SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
|
144
147
|
p_mixids = mixdb.mixids_to_list(mixids)
|
145
148
|
if len(p_mixids) != mixdb.num_mixtures:
|
146
149
|
logger.info(f'Processing a subset of {p_mixids} from available mixtures.')
|
147
|
-
else:
|
148
|
-
|
149
|
-
entries: list[PathInfo] = []
|
150
|
-
for p in datapaths:
|
150
|
+
else:
|
151
|
+
for p in data_paths:
|
151
152
|
location = join(realpath(abspath(p)), '**', include)
|
152
153
|
logger.debug(f'Processing {location}')
|
153
154
|
for file in braced_iglob(pathname=location, recursive=True):
|
154
155
|
name = file
|
155
156
|
entries.append(PathInfo(abs_path=file, audio_filepath=name))
|
157
|
+
logger.info(f'{len(data_paths)} data paths specified, found {len(entries)} audio files.')
|
156
158
|
|
157
|
-
from sonusai.utils import create_ts_name
|
158
|
-
from os import makedirs
|
159
|
-
from sonusai import create_file_handler
|
160
|
-
from sonusai import initial_log_messages
|
161
|
-
from sonusai import update_console_handler
|
162
|
-
output_dir = create_ts_name('opredict-' + in_basename)
|
163
|
-
makedirs(output_dir, exist_ok=True)
|
164
|
-
# Setup logging file
|
165
|
-
create_file_handler(join(output_dir, 'onnx-predict.log'))
|
166
|
-
update_console_handler(verbose)
|
167
|
-
initial_log_messages('onnx_predict')
|
168
|
-
# Reprint some info messages
|
169
|
-
logger.info(f'Loaded OnnxRuntime, available providers: {providers}.')
|
170
|
-
logger.info(f'Read and compiled onnx model from {model_path}.')
|
171
|
-
if len(datapaths) == 1 and isdir(datapaths[0]): # Assume it's a single path to sonusai mixdb subdir
|
172
|
-
logger.info(f'Loaded mixture database from {datapaths}')
|
173
|
-
logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
|
174
|
-
else:
|
175
|
-
logger.info(f'{len(datapaths)} data paths specified, found {len(entries)} audio files.')
|
176
159
|
if in0type.find('float16') != -1:
|
177
160
|
model_is_fp16 = True
|
178
161
|
logger.info(f'Detected input of float16, converting all feature inputs to that type.')
|
179
162
|
else:
|
180
163
|
model_is_fp16 = False
|
181
164
|
|
182
|
-
|
183
|
-
|
165
|
+
if mixdb_path is not None and hparams['batch_size'] == 1:
|
166
|
+
# mixdb input
|
184
167
|
# Assume (of course) that mixdb feature, etc. is what model expects
|
185
|
-
if hparams[
|
168
|
+
if hparams['feature'] != mixdb.feature:
|
186
169
|
logger.warning(f'Mixture feature does not match model feature, this inference run may fail.')
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
if
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
#
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
if writewav: # note only makes sense if model is predicting audio, i.e. tstep dimension exists
|
223
|
-
# predict_audio wants [frames, channels, feature_parameters] equiv. to tsteps, batch, bins
|
224
|
-
predict = np.transpose(predict, [1, 0, 2])
|
225
|
-
predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
|
226
|
-
owav_name = splitext(output_fname)[0]+'_predict.wav'
|
227
|
-
write_wav(owav_name, predict_audio)
|
228
|
-
|
229
|
-
|
230
|
-
#
|
231
|
-
# # sampler = None
|
232
|
-
# # p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
|
233
|
-
# # mixids=p_mixids,
|
234
|
-
# # batch_size=batch_size,
|
235
|
-
# # cut_len=0,
|
236
|
-
# # flatten=model.flatten,
|
237
|
-
# # add1ch=model.add1ch,
|
238
|
-
# # random_cut=False,
|
239
|
-
# # sampler=sampler,
|
240
|
-
# # drop_last=False,
|
241
|
-
# # num_workers=dlcpu)
|
242
|
-
#
|
243
|
-
# # Info needed to set up inverse transform
|
244
|
-
# half = model.num_classes // 2
|
245
|
-
# fg = FeatureGenerator(feature_mode=feature,
|
246
|
-
# num_classes=model.num_classes,
|
247
|
-
# truth_mutex=model.truth_mutex)
|
248
|
-
# itf = TorchInverseTransform(N=fg.itransform_N,
|
249
|
-
# R=fg.itransform_R,
|
250
|
-
# bin_start=fg.bin_start,
|
251
|
-
# bin_end=fg.bin_end,
|
252
|
-
# ttype=fg.itransform_ttype)
|
253
|
-
#
|
254
|
-
# enable_truth_wav = False
|
255
|
-
# enable_mix_wav = False
|
256
|
-
# if wavdbg:
|
257
|
-
# if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
|
258
|
-
# enable_mix_wav = True
|
259
|
-
# enable_truth_wav = True
|
260
|
-
# elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
|
261
|
-
# enable_truth_wav = True
|
262
|
-
#
|
263
|
-
# if reset:
|
264
|
-
# logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
|
265
|
-
# for idx, val in enumerate(p_datagen):
|
266
|
-
# # truth = val[1]
|
267
|
-
# feature = val[0]
|
268
|
-
# with torch.no_grad():
|
269
|
-
# ypred = model(feature)
|
270
|
-
# output_name = join(output_dir, mixdb.mixtures[idx].name)
|
271
|
-
# pdat = ypred.detach().numpy()
|
272
|
-
# if timesteps > 0:
|
273
|
-
# logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
|
274
|
-
# logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
|
275
|
-
# with h5py.File(output_name, 'a') as f:
|
276
|
-
# if 'predict' in f:
|
277
|
-
# del f['predict']
|
278
|
-
# f.create_dataset('predict', data=pdat)
|
279
|
-
#
|
280
|
-
# if wavdbg:
|
281
|
-
# owav_base = splitext(output_name)[0]
|
282
|
-
# tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
|
283
|
-
# itf.reset()
|
284
|
-
# predwav, _ = itf.execute_all(tmp)
|
285
|
-
# # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
|
286
|
-
# write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
|
287
|
-
# if enable_truth_wav:
|
288
|
-
# # Note this support truth type target_f and target_mixture_f
|
289
|
-
# tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
|
290
|
-
# itf.reset()
|
291
|
-
# truthwav, _ = itf.execute_all(tmp)
|
292
|
-
# write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
|
293
|
-
#
|
294
|
-
# if enable_mix_wav:
|
295
|
-
# tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
|
296
|
-
# itf.reset()
|
297
|
-
# mixwav, _ = itf.execute_all(tmp.detach())
|
298
|
-
# write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
|
299
|
-
#
|
300
|
-
#
|
301
|
-
#
|
302
|
-
#
|
303
|
-
#
|
304
|
-
#
|
305
|
-
#
|
306
|
-
#
|
307
|
-
#
|
308
|
-
#
|
309
|
-
#
|
310
|
-
#
|
311
|
-
#
|
312
|
-
#
|
313
|
-
#
|
314
|
-
#
|
315
|
-
#
|
316
|
-
# from os import makedirs
|
317
|
-
# from os.path import isdir
|
318
|
-
# from os.path import join
|
319
|
-
# from os.path import splitext
|
320
|
-
#
|
321
|
-
# import h5py
|
322
|
-
# import onnxruntime as rt
|
323
|
-
# import numpy as np
|
324
|
-
#
|
325
|
-
# from sonusai.mixture import Feature
|
326
|
-
# from sonusai.mixture import Predict
|
327
|
-
# from sonusai.utils import SonusAIMetaData
|
328
|
-
# from sonusai import create_file_handler
|
329
|
-
# from sonusai import initial_log_messages
|
330
|
-
# from sonusai import update_console_handler
|
331
|
-
# from sonusai.mixture import MixtureDatabase
|
332
|
-
# from sonusai.mixture import get_feature_from_audio
|
333
|
-
# from sonusai.mixture import read_audio
|
334
|
-
# from sonusai.utils import create_ts_name
|
335
|
-
# from sonusai.utils import get_frames_per_batch
|
336
|
-
# from sonusai.utils import get_sonusai_metadata
|
337
|
-
#
|
338
|
-
# output_dir = create_ts_name('ovpredict')
|
339
|
-
# makedirs(output_dir, exist_ok=True)
|
340
|
-
#
|
341
|
-
#
|
342
|
-
#
|
343
|
-
# model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
|
344
|
-
# model_metadata = get_sonusai_metadata(model)
|
345
|
-
#
|
346
|
-
# batch_size = model_metadata.input_shape[0]
|
347
|
-
# if model_metadata.timestep:
|
348
|
-
# timesteps = model_metadata.input_shape[1]
|
349
|
-
# else:
|
350
|
-
# timesteps = 0
|
351
|
-
# num_classes = model_metadata.output_shape[-1]
|
352
|
-
#
|
353
|
-
# frames_per_batch = get_frames_per_batch(batch_size, timesteps)
|
354
|
-
#
|
355
|
-
# logger.info('')
|
356
|
-
# logger.info(f'feature {model_metadata.feature}')
|
357
|
-
# logger.info(f'num_classes {num_classes}')
|
358
|
-
# logger.info(f'batch_size {batch_size}')
|
359
|
-
# logger.info(f'timesteps {timesteps}')
|
360
|
-
# logger.info(f'flatten {model_metadata.flattened}')
|
361
|
-
# logger.info(f'add1ch {model_metadata.channel}')
|
362
|
-
# logger.info(f'truth_mutex {model_metadata.mutex}')
|
363
|
-
# logger.info(f'input_shape {model_metadata.input_shape}')
|
364
|
-
# logger.info(f'output_shape {model_metadata.output_shape}')
|
365
|
-
# logger.info('')
|
366
|
-
#
|
367
|
-
# if splitext(entries)[1] == '.wav':
|
368
|
-
# # Convert WAV to feature data
|
369
|
-
# logger.info('')
|
370
|
-
# logger.info(f'Run prediction on {entries}')
|
371
|
-
# audio = read_audio()
|
372
|
-
# feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
|
373
|
-
#
|
374
|
-
# predict = pad_and_predict(feature=feature,
|
375
|
-
# model_name=model_name,
|
376
|
-
# model_metadata=model_metadata,
|
377
|
-
# frames_per_batch=frames_per_batch,
|
378
|
-
# batch_size=batch_size,
|
379
|
-
# timesteps=timesteps,
|
380
|
-
# reset=reset)
|
381
|
-
#
|
382
|
-
# output_name = splitext()[0] + '.h5'
|
383
|
-
# with h5py.File(output_name, 'a') as f:
|
384
|
-
# if 'feature' in f:
|
385
|
-
# del f['feature']
|
386
|
-
# f.create_dataset(name='feature', data=feature)
|
387
|
-
#
|
388
|
-
# if 'predict' in f:
|
389
|
-
# del f['predict']
|
390
|
-
# f.create_dataset(name='predict', data=predict)
|
391
|
-
#
|
392
|
-
# logger.info(f'Saved results to {output_name}')
|
393
|
-
# return
|
394
|
-
#
|
395
|
-
# if not isdir():
|
396
|
-
# logger.exception(f'Do not know how to process input from {entries}')
|
397
|
-
# raise SystemExit(1)
|
398
|
-
#
|
399
|
-
# mixdb = MixtureDatabase()
|
400
|
-
#
|
401
|
-
# if mixdb.feature != model_metadata.feature:
|
402
|
-
# logger.exception(f'Feature in mixture database does not match feature in model')
|
403
|
-
# raise SystemExit(1)
|
404
|
-
#
|
405
|
-
# mixids = mixdb.mixids_to_list(mixids)
|
406
|
-
# if reset:
|
407
|
-
# # reset mode cycles through each file one at a time
|
408
|
-
# for mixid in mixids:
|
409
|
-
# feature, _ = mixdb.mixture_ft(mixid)
|
410
|
-
#
|
411
|
-
# predict = pad_and_predict(feature=feature,
|
412
|
-
# model_name=model_name,
|
413
|
-
# model_metadata=model_metadata,
|
414
|
-
# frames_per_batch=frames_per_batch,
|
415
|
-
# batch_size=batch_size,
|
416
|
-
# timesteps=timesteps,
|
417
|
-
# reset=reset)
|
418
|
-
#
|
419
|
-
# output_name = join(output_dir, mixdb.mixtures[mixid].name)
|
420
|
-
# with h5py.File(output_name, 'a') as f:
|
421
|
-
# if 'predict' in f:
|
422
|
-
# del f['predict']
|
423
|
-
# f.create_dataset(name='predict', data=predict)
|
424
|
-
# else:
|
425
|
-
# features: list[Feature] = []
|
426
|
-
# file_indices: list[slice] = []
|
427
|
-
# total_frames = 0
|
428
|
-
# for mixid in mixids:
|
429
|
-
# current_feature, _ = mixdb.mixture_ft(mixid)
|
430
|
-
# current_frames = current_feature.shape[0]
|
431
|
-
# features.append(current_feature)
|
432
|
-
# file_indices.append(slice(total_frames, total_frames + current_frames))
|
433
|
-
# total_frames += current_frames
|
434
|
-
# feature = np.vstack([features[i] for i in range(len(features))])
|
435
|
-
#
|
436
|
-
# predict = pad_and_predict(feature=feature,
|
437
|
-
# model_name=model_name,
|
438
|
-
# model_metadata=model_metadata,
|
439
|
-
# frames_per_batch=frames_per_batch,
|
440
|
-
# batch_size=batch_size,
|
441
|
-
# timesteps=timesteps,
|
442
|
-
# reset=reset)
|
443
|
-
#
|
444
|
-
# # Write data to separate files
|
445
|
-
# for idx, mixid in enumerate(mixids):
|
446
|
-
# output_name = join(output_dir, mixdb.mixtures[mixid].name)
|
447
|
-
# with h5py.File(output_name, 'a') as f:
|
448
|
-
# if 'predict' in f:
|
449
|
-
# del f['predict']
|
450
|
-
# f.create_dataset('predict', data=predict[file_indices[idx]])
|
451
|
-
#
|
452
|
-
# logger.info(f'Saved results to {output_dir}')
|
453
|
-
#
|
454
|
-
|
455
|
-
# def pad_and_predict(feature: Feature,
|
456
|
-
# model_name: str,
|
457
|
-
# model_metadata: SonusAIMetaData,
|
458
|
-
# frames_per_batch: int,
|
459
|
-
# batch_size: int,
|
460
|
-
# timesteps: int,
|
461
|
-
# reset: bool) -> Predict:
|
462
|
-
# import onnxruntime as rt
|
463
|
-
# import numpy as np
|
464
|
-
#
|
465
|
-
# from sonusai.utils import reshape_inputs
|
466
|
-
# from sonusai.utils import reshape_outputs
|
467
|
-
#
|
468
|
-
# frames = feature.shape[0]
|
469
|
-
# padding = frames_per_batch - frames % frames_per_batch
|
470
|
-
# feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
|
471
|
-
# feature, _ = reshape_inputs(feature=feature,
|
472
|
-
# batch_size=batch_size,
|
473
|
-
# timesteps=timesteps,
|
474
|
-
# flatten=model_metadata.flattened,
|
475
|
-
# add1ch=model_metadata.channel)
|
476
|
-
# sequences = feature.shape[0] // model_metadata.input_shape[0]
|
477
|
-
# feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
|
478
|
-
#
|
479
|
-
# model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
|
480
|
-
# output_names = [n.name for n in model.get_outputs()]
|
481
|
-
# input_names = [n.name for n in model.get_inputs()]
|
482
|
-
#
|
483
|
-
# predict = []
|
484
|
-
# for sequence in range(sequences):
|
485
|
-
# predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
|
486
|
-
# if reset:
|
487
|
-
# model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
|
488
|
-
#
|
489
|
-
# predict_arr = np.vstack(predict)
|
490
|
-
# # Combine [sequences, batch_size, ...] into [frames, ...]
|
491
|
-
# predict_shape = predict_arr.shape
|
492
|
-
# predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
|
493
|
-
# predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
|
494
|
-
# predict_arr = predict_arr[:frames, :]
|
495
|
-
#
|
496
|
-
# return predict_arr
|
170
|
+
# no choice, can't use hparams.feature since it's different from the mixdb
|
171
|
+
feature_mode = mixdb.feature
|
172
|
+
|
173
|
+
for mixid in p_mixids:
|
174
|
+
# frames x stride x feature_params
|
175
|
+
feature, _ = mixdb.mixture_ft(mixid)
|
176
|
+
if hparams['timesteps'] == 0:
|
177
|
+
# no timestep dimension, reshape will handle
|
178
|
+
timesteps = 0
|
179
|
+
else:
|
180
|
+
# fit frames into timestep dimension (TSE mode)
|
181
|
+
timesteps = feature.shape[0]
|
182
|
+
|
183
|
+
feature, _ = reshape_inputs(feature=feature,
|
184
|
+
batch_size=1,
|
185
|
+
timesteps=timesteps,
|
186
|
+
flatten=hparams['flatten'],
|
187
|
+
add1ch=hparams['add1ch'])
|
188
|
+
if model_is_fp16:
|
189
|
+
feature = np.float16(feature) # type: ignore
|
190
|
+
# run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
|
191
|
+
predict = session.run(out_names, {in0name: feature})[0]
|
192
|
+
# predict, _ = reshape_outputs(predict=predict[0], timesteps=frames) # frames x feat_params
|
193
|
+
output_fname = join(output_dir, mixdb.mixtures[mixid].name)
|
194
|
+
with h5py.File(output_fname, 'a') as f:
|
195
|
+
if 'predict' in f:
|
196
|
+
del f['predict']
|
197
|
+
f.create_dataset('predict', data=predict)
|
198
|
+
if wav:
|
199
|
+
# note only makes sense if model is predicting audio, i.e., timestep dimension exists
|
200
|
+
# predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
|
201
|
+
predict = np.transpose(predict, [1, 0, 2])
|
202
|
+
predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
|
203
|
+
owav_name = splitext(output_fname)[0] + '_predict.wav'
|
204
|
+
write_wav(owav_name, predict_audio)
|
497
205
|
|
498
206
|
|
499
207
|
if __name__ == '__main__':
|
500
|
-
|
501
|
-
main()
|
502
|
-
except KeyboardInterrupt:
|
503
|
-
logger.info('Canceled due to keyboard interrupt')
|
504
|
-
raise SystemExit(0)
|
208
|
+
main()
|
sonusai/queries/queries.py
CHANGED