sonusai 0.16.1__py3-none-any.whl → 0.17.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/onnx_predict.py CHANGED
@@ -1,19 +1,23 @@
-"""sonusai predict
+"""sonusai onnx_predict

-usage: predict [-hvr] [-i MIXID] (-m MODEL) INPUT
+usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...

 options:
    -h, --help
    -v, --verbose                Be verbose.
+   -l, --list-device-details    List details of all available OpenVINO devices.
    -i MIXID, --mixid MIXID      Mixture ID(s) to generate if input is a mixture database. [default: *].
-   -m MODEL, --model MODEL      Trained ONNX model file.
+   --include GLOB               Search only files whose base name matches GLOB. [default: *.{wav,flac}].
+   -w, --write-wav              Calculate the inverse transform of the prediction and write .wav files.
    -r, --reset                  Reset model between each file.

-Run prediction on a trained ONNX model using SonusAI genft or WAV data.
+Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or on audio files found under glob paths.
+The ONNX Runtime (ORT) inference engine is used to execute the inference.

 Inputs:
-    MODEL       A SonusAI trained ONNX model file.
-    INPUT       The input data must be one of the following:
+    MODEL       ONNX model .onnx file of a trained model (weights are expected to be in the file).
+
+    DATA        The input data must be one of the following:
                 * WAV
                   Using the given model, generate feature data and run prediction. A model file must be
                   provided. The MIXID is ignored.
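
The usage above replaces the old single INPUT argument with MODEL plus one or more DATA paths filtered by the --include glob. For reference, the file search this enables can be sketched as follows (illustrative only: braced_iglob and its pathname/recursive arguments are taken from the code added below in this diff; the find_audio wrapper itself is not part of the package):

    # Resolve DATA paths against an --include glob such as '*.{wav,flac}'.
    # braced_iglob is assumed to expand brace patterns, unlike stdlib glob.
    from os.path import abspath, join, realpath

    from sonusai.utils import braced_iglob

    def find_audio(datapaths: list[str], include: str = '*.{wav,flac}') -> list[str]:
        files: list[str] = []
        for p in datapaths:
            location = join(realpath(abspath(p)), '**', include)
            files.extend(braced_iglob(pathname=location, recursive=True))
        return files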
@@ -22,30 +26,74 @@ Inputs:
                   Using the given SonusAI mixture database directory, generate feature and truth data if not found.
                   Run prediction. The MIXID is required.

-Outputs the following to opredict-<TIMESTAMP> directory:
+
+Note there are multiple ways to process model prediction over multiple audio data files:
+1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model is run
+   as a single inference call. If batch_size > 1 then multiple mixtures are run in one call, with shorter mixtures
+   zero-padded to the size of the largest mixture.
+2. TME (timestep multi-extension): the mixture is split into multiple timesteps, i.e. batch[0] holds the starting
+   timesteps, etc. Note that batches are run independently, so sequential state from one set of timesteps to the next
+   is not maintained; results for such models (i.e. conv, LSTMs in the tstep dimension) would not match TSE mode.
+
+TBD: not sure the below makes sense, needs more work ??
+3. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This makes sense only if
+   independent predictions are made on each frame without considering previous frames (i.e. timesteps=1, or there is
+   no timestep dimension in the model, timesteps=0).
+4. classification
+
+Outputs the following to ovpredict-<TIMESTAMP> directory:
     <id>.h5
         dataset:    predict
     onnx_predict.log

 """

-import signal
-
-from sonusai.mixture import Feature
-from sonusai.mixture import Predict
-from sonusai.utils import SonusAIMetaData
-
-
-def signal_handler(_sig, _frame):
-    import sys
+from sonusai import logger
+from typing import Any, List, Optional, Tuple
+from os.path import basename, splitext, exists, isfile
+import onnxruntime as ort
+import onnx
+from onnx import ValueInfoProto
+
+
+def load_ort_session(model_path, providers=['CPUExecutionProvider']):
+    if exists(model_path) and isfile(model_path):
+        model_basename = basename(model_path)
+        model_root = splitext(model_basename)[0]
+        logger.info(f'Importing model from {model_basename}')
+        try:
+            session = ort.InferenceSession(model_path, providers=providers)
+            options = ort.SessionOptions()
+        except Exception as e:
+            logger.exception(f'Error: could not load onnx model from {model_path}: {e}')
+            raise SystemExit(1)
+    else:
+        logger.exception(f'Error: model file does not exist: {model_path}')
+        raise SystemExit(1)

-    from sonusai import logger
+    logger.info(f'Opened session with provider options: {session._provider_options}.')
+    try:
+        meta = session.get_modelmeta()
+        hparams = eval(meta.custom_metadata_map["hparams"])
+        logger.info(f'SonusAI hyper-parameter metadata was found in model with {len(hparams)} parameters, '
+                    f'checking for required ones ...')
+        # Printing to the log here will fail if the required parameters are not available.
+        logger.info(f'feature          {hparams["feature"]}')
+        logger.info(f'batch_size       {hparams["batch_size"]}')
+        logger.info(f'timesteps        {hparams["timesteps"]}')
+        logger.info(f'flatten, add1ch  {hparams["flatten"]}, {hparams["add1ch"]}')
+        logger.info(f'truth_mutex      {hparams["truth_mutex"]}')
+    except Exception:
+        hparams = None
+        logger.warning('Warning: onnx model does not have the required SonusAI hyper-parameters.')

-    logger.info('Canceled due to keyboard interrupt')
-    sys.exit(1)
+    inputs = session.get_inputs()
+    outputs = session.get_outputs()

+    # in_names = [n.name for n in session.get_inputs()]
+    # out_names = [n.name for n in session.get_outputs()]

-signal.signal(signal.SIGINT, signal_handler)
+    return session, options, model_root, hparams, inputs, outputs


 def main() -> None:
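
A brief usage sketch for the load_ort_session helper added above (the model path here is hypothetical; hparams is only populated when the model's ONNX metadata carries the SonusAI 'hparams' entry written at training time):

    # Open a session and inspect the SonusAI metadata and graph I/O.
    session, options, model_root, hparams, inputs, outputs = load_ort_session('model.onnx')

    if hparams is not None:
        # The required SonusAI hyper-parameters were found in the model metadata.
        print(hparams['feature'], hparams['batch_size'], hparams['timesteps'])

    # ONNX Runtime NodeArg objects expose name, type, and shape for each graph input/output.
    for node in inputs:
        print(node.name, node.type, node.shape)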
@@ -57,194 +105,400 @@ def main() -> None:
     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

     verbose = args['--verbose']
+    listdd = args['--list-device-details']
+    writewav = args['--write-wav']
     mixids = args['--mixid']
-    model_name = args['--model']
     reset = args['--reset']
-    input_name = args['INPUT']
+    include = args['--include']
+    model_path = args['MODEL']
+    datapaths = args['DATA']

-    from os import makedirs
-    from os.path import isdir
-    from os.path import join
-    from os.path import splitext
+    providers = ort.get_available_providers()
+    logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
+
+    session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
+    if hparams is None:
+        logger.error('Error: onnx model does not have the required SonusAI hyper-parameters, cannot proceed.')
+        raise SystemExit(1)
+    if len(sess_inputs) != 1:
+        logger.error(f'Error: onnx model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
+        raise SystemExit(1)
+
+    in0name = sess_inputs[0].name
+    in0type = sess_inputs[0].type
+    out0name = sess_outputs[0].name
+    out_names = [n.name for n in session.get_outputs()]
+
+    from os.path import join, dirname, isdir, normpath, realpath, abspath
+    from sonusai.utils.asr_manifest_functions import PathInfo
+    from sonusai.utils import braced_iglob
+    from sonusai.mixture import MixtureDatabase

-    import h5py
-    import onnxruntime as rt
-    import numpy as np
+    mixdb_path = None
+    entries = None
+    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to a SonusAI mixdb subdirectory
+        in_basename = basename(normpath(datapaths[0]))
+        mixdb_path = datapaths[0]
+        logger.debug(f'Attempting to load mixture database from {mixdb_path}')
+        mixdb = MixtureDatabase(mixdb_path)
+        logger.debug(f'SonusAI mixture db load success: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
+        p_mixids = mixdb.mixids_to_list(mixids)
+        if len(p_mixids) != mixdb.num_mixtures:
+            logger.info(f'Processing a subset of {len(p_mixids)} of the available mixtures.')
+    else:  # search all datapaths for .wav, .flac (or whatever is specified in include)
+        in_basename = ''
+        entries: list[PathInfo] = []
+        for p in datapaths:
+            location = join(realpath(abspath(p)), '**', include)
+            logger.debug(f'Processing {location}')
+            for file in braced_iglob(pathname=location, recursive=True):
+                name = file
+                entries.append(PathInfo(abs_path=file, audio_filepath=name))

+    from sonusai.utils import create_ts_name
+    from os import makedirs
     from sonusai import create_file_handler
     from sonusai import initial_log_messages
-    from sonusai import logger
     from sonusai import update_console_handler
-    from sonusai.mixture import MixtureDatabase
-    from sonusai.mixture import get_feature_from_audio
-    from sonusai.mixture import read_audio
-    from sonusai.utils import create_ts_name
-    from sonusai.utils import get_frames_per_batch
-    from sonusai.utils import get_sonusai_metadata
-
-    output_dir = create_ts_name('opredict')
+    output_dir = create_ts_name('opredict-' + in_basename)
     makedirs(output_dir, exist_ok=True)
-
     # Setup logging file
-    create_file_handler(join(output_dir, 'onnx_predict.log'))
+    create_file_handler(join(output_dir, 'onnx-predict.log'))
     update_console_handler(verbose)
     initial_log_messages('onnx_predict')
-
-    model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-    model_metadata = get_sonusai_metadata(model)
-
-    batch_size = model_metadata.input_shape[0]
-    if model_metadata.timestep:
-        timesteps = model_metadata.input_shape[1]
+    # Reprint some info messages now that the log file is set up
+    logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
+    logger.info(f'Read and compiled onnx model from {model_path}.')
+    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to a SonusAI mixdb subdirectory
+        logger.info(f'Loaded mixture database from {datapaths}')
+        logger.info(f'SonusAI mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
     else:
-        timesteps = 0
-    num_classes = model_metadata.output_shape[-1]
-
-    frames_per_batch = get_frames_per_batch(batch_size, timesteps)
-
-    logger.info('')
-    logger.info(f'feature {model_metadata.feature}')
-    logger.info(f'num_classes {num_classes}')
-    logger.info(f'batch_size {batch_size}')
-    logger.info(f'timesteps {timesteps}')
-    logger.info(f'flatten {model_metadata.flattened}')
-    logger.info(f'add1ch {model_metadata.channel}')
-    logger.info(f'truth_mutex {model_metadata.mutex}')
-    logger.info(f'input_shape {model_metadata.input_shape}')
-    logger.info(f'output_shape {model_metadata.output_shape}')
-    logger.info('')
-
-    if splitext(input_name)[1] == '.wav':
-        # Convert WAV to feature data
-        logger.info('')
-        logger.info(f'Run prediction on {input_name}')
-        audio = read_audio(input_name)
-        feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
-
-        predict = pad_and_predict(feature=feature,
-                                  model_name=model_name,
-                                  model_metadata=model_metadata,
-                                  frames_per_batch=frames_per_batch,
-                                  batch_size=batch_size,
-                                  timesteps=timesteps,
-                                  reset=reset)
-
-        output_name = splitext(input_name)[0] + '.h5'
-        with h5py.File(output_name, 'a') as f:
-            if 'feature' in f:
-                del f['feature']
-            f.create_dataset(name='feature', data=feature)
-
-            if 'predict' in f:
-                del f['predict']
-            f.create_dataset(name='predict', data=predict)
-
-        logger.info(f'Saved results to {output_name}')
-        return
-
-    if not isdir(input_name):
-        logger.exception(f'Do not know how to process input from {input_name}')
-        raise SystemExit(1)
-
-    mixdb = MixtureDatabase(input_name)
-
-    if mixdb.feature != model_metadata.feature:
-        logger.exception(f'Feature in mixture database does not match feature in model')
-        raise SystemExit(1)
-
-    mixids = mixdb.mixids_to_list(mixids)
-    if reset:
-        # reset mode cycles through each file one at a time
-        for mixid in mixids:
-            feature, _ = mixdb.mixture_ft(mixid)
-
-            predict = pad_and_predict(feature=feature,
-                                      model_name=model_name,
-                                      model_metadata=model_metadata,
-                                      frames_per_batch=frames_per_batch,
-                                      batch_size=batch_size,
-                                      timesteps=timesteps,
-                                      reset=reset)
-
-            output_name = join(output_dir, mixdb.mixtures[mixid].name)
-            with h5py.File(output_name, 'a') as f:
-                if 'predict' in f:
-                    del f['predict']
-                f.create_dataset(name='predict', data=predict)
+        logger.info(f'{len(datapaths)} data paths specified, found {len(entries)} audio files.')
+    if in0type.find('float16') != -1:
+        model_is_fp16 = True
+        logger.info('Detected input of float16, converting all feature inputs to that type.')
     else:
-        features: list[Feature] = []
-        file_indices: list[slice] = []
-        total_frames = 0
-        for mixid in mixids:
-            current_feature, _ = mixdb.mixture_ft(mixid)
-            current_frames = current_feature.shape[0]
-            features.append(current_feature)
-            file_indices.append(slice(total_frames, total_frames + current_frames))
-            total_frames += current_frames
-        feature = np.vstack([features[i] for i in range(len(features))])
-
-        predict = pad_and_predict(feature=feature,
-                                  model_name=model_name,
-                                  model_metadata=model_metadata,
-                                  frames_per_batch=frames_per_batch,
-                                  batch_size=batch_size,
-                                  timesteps=timesteps,
-                                  reset=reset)
-
-        # Write data to separate files
-        for idx, mixid in enumerate(mixids):
-            output_name = join(output_dir, mixdb.mixtures[mixid].name)
-            with h5py.File(output_name, 'a') as f:
-                if 'predict' in f:
-                    del f['predict']
-                f.create_dataset('predict', data=predict[file_indices[idx]])
-
-    logger.info(f'Saved results to {output_dir}')
-
-
-def pad_and_predict(feature: Feature,
-                    model_name: str,
-                    model_metadata: SonusAIMetaData,
-                    frames_per_batch: int,
-                    batch_size: int,
-                    timesteps: int,
-                    reset: bool) -> Predict:
-    import onnxruntime as rt
-    import numpy as np
-
-    from sonusai.utils import reshape_inputs
-    from sonusai.utils import reshape_outputs
-
-    frames = feature.shape[0]
-    padding = frames_per_batch - frames % frames_per_batch
-    feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
-    feature, _ = reshape_inputs(feature=feature,
-                                batch_size=batch_size,
-                                timesteps=timesteps,
-                                flatten=model_metadata.flattened,
-                                add1ch=model_metadata.channel)
-    sequences = feature.shape[0] // model_metadata.input_shape[0]
-    feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
-
-    model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-    output_names = [n.name for n in model.get_outputs()]
-    input_names = [n.name for n in model.get_inputs()]
-
-    predict = []
-    for sequence in range(sequences):
-        predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
-        if reset:
-            model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-
-    predict_arr = np.vstack(predict)
-    # Combine [sequences, batch_size, ...] into [frames, ...]
-    predict_shape = predict_arr.shape
-    predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
-    predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
-    predict_arr = predict_arr[:frames, :]
-
-    return predict_arr
+        model_is_fp16 = False
+
+
+    if mixdb_path is not None:  # mixdb input
+        # Assume (of course) that the mixdb feature, etc. is what the model expects
+        if hparams["feature"] != mixdb.feature:
+            logger.warning('Mixture feature does not match model feature, this inference run may fail.')
+        feature_mode = mixdb.feature  # no choice; can't use hparams["feature"] since it differs from the mixdb
+
+        # if hparams["num_classes"] != mixdb.num_classes:  # needs to be i.e. mixdb.feature_parameters
+        # if mixdb.num_classes != model_num_classes:
+        #     logger.error(f'Feature parameters in mixture db {mixdb.num_classes} does not match num_classes in model {inp0shape[-1]}')
+        #     raise SystemExit(1)
+
+        from sonusai.utils import reshape_inputs
+        from sonusai.utils import reshape_outputs
+        from sonusai.mixture import get_audio_from_feature
+        from sonusai.utils import write_wav
+        import numpy as np
+        import h5py
+        if hparams["batch_size"] == 1:
+            for mixid in p_mixids:
+                feature, _ = mixdb.mixture_ft(mixid)  # frames x stride x feature_params
+                if hparams["timesteps"] == 0:
+                    tsteps = 0  # no timestep dimension; reshape will take care of it
+                else:
+                    tsteps = feature.shape[0]  # fit frames into the timestep dimension (TSE mode)
+                feature, _ = reshape_inputs(feature=feature,
+                                            batch_size=1,
+                                            timesteps=tsteps,
+                                            flatten=hparams["flatten"],
+                                            add1ch=hparams["add1ch"])
+                if model_is_fp16:
+                    feature = np.float16(feature)
+                # Run inference; the ort session wants batch x tsteps x feat_params and outputs numpy BxTxFP or BxFP
+                predict = session.run(out_names, {in0name: feature})[0]
+                # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames)  # frames x feat_params
+                output_fname = join(output_dir, mixdb.mixtures[mixid].name)
+                with h5py.File(output_fname, 'a') as f:
+                    if 'predict' in f:
+                        del f['predict']
+                    f.create_dataset('predict', data=predict)
+                if writewav:  # note: only makes sense if the model is predicting audio, i.e. the tstep dimension exists
+                    # predict_audio wants [frames, channels, feature_parameters], equivalent to tsteps, batch, bins
+                    predict = np.transpose(predict, [1, 0, 2])
+                    predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
+                    owav_name = splitext(output_fname)[0] + '_predict.wav'
+                    write_wav(owav_name, predict_audio)
+
+
+    #
+    # # sampler = None
+    # # p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
+    # #                                      mixids=p_mixids,
+    # #                                      batch_size=batch_size,
+    # #                                      cut_len=0,
+    # #                                      flatten=model.flatten,
+    # #                                      add1ch=model.add1ch,
+    # #                                      random_cut=False,
+    # #                                      sampler=sampler,
+    # #                                      drop_last=False,
+    # #                                      num_workers=dlcpu)
+    #
+    # # Info needed to set up inverse transform
+    # half = model.num_classes // 2
+    # fg = FeatureGenerator(feature_mode=feature,
+    #                       num_classes=model.num_classes,
+    #                       truth_mutex=model.truth_mutex)
+    # itf = TorchInverseTransform(N=fg.itransform_N,
+    #                             R=fg.itransform_R,
+    #                             bin_start=fg.bin_start,
+    #                             bin_end=fg.bin_end,
+    #                             ttype=fg.itransform_ttype)
+    #
+    # enable_truth_wav = False
+    # enable_mix_wav = False
+    # if wavdbg:
+    #     if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
+    #         enable_mix_wav = True
+    #         enable_truth_wav = True
+    #     elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
+    #         enable_truth_wav = True
+    #
+    # if reset:
+    #     logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
+    #     for idx, val in enumerate(p_datagen):
+    #         # truth = val[1]
+    #         feature = val[0]
+    #         with torch.no_grad():
+    #             ypred = model(feature)
+    #         output_name = join(output_dir, mixdb.mixtures[idx].name)
+    #         pdat = ypred.detach().numpy()
+    #         if timesteps > 0:
+    #             logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
+    #         logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset('predict', data=pdat)
+    #
+    #         if wavdbg:
+    #             owav_base = splitext(output_name)[0]
+    #             tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
+    #             itf.reset()
+    #             predwav, _ = itf.execute_all(tmp)
+    #             # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
+    #             write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
+    #             if enable_truth_wav:
+    #                 # Note this supports truth types target_f and target_mixture_f
+    #                 tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
+    #                 itf.reset()
+    #                 truthwav, _ = itf.execute_all(tmp)
+    #                 write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
+    #
+    #             if enable_mix_wav:
+    #                 tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
+    #                 itf.reset()
+    #                 mixwav, _ = itf.execute_all(tmp.detach())
+    #                 write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    # from os import makedirs
+    # from os.path import isdir
+    # from os.path import join
+    # from os.path import splitext
+    #
+    # import h5py
+    # import onnxruntime as rt
+    # import numpy as np
+    #
+    # from sonusai.mixture import Feature
+    # from sonusai.mixture import Predict
+    # from sonusai.utils import SonusAIMetaData
+    # from sonusai import create_file_handler
+    # from sonusai import initial_log_messages
+    # from sonusai import update_console_handler
+    # from sonusai.mixture import MixtureDatabase
+    # from sonusai.mixture import get_feature_from_audio
+    # from sonusai.mixture import read_audio
+    # from sonusai.utils import create_ts_name
+    # from sonusai.utils import get_frames_per_batch
+    # from sonusai.utils import get_sonusai_metadata
+    #
+    # output_dir = create_ts_name('ovpredict')
+    # makedirs(output_dir, exist_ok=True)
+    #
+    #
+    #
+    # model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+    # model_metadata = get_sonusai_metadata(model)
+    #
+    # batch_size = model_metadata.input_shape[0]
+    # if model_metadata.timestep:
+    #     timesteps = model_metadata.input_shape[1]
+    # else:
+    #     timesteps = 0
+    # num_classes = model_metadata.output_shape[-1]
+    #
+    # frames_per_batch = get_frames_per_batch(batch_size, timesteps)
+    #
+    # logger.info('')
+    # logger.info(f'feature {model_metadata.feature}')
+    # logger.info(f'num_classes {num_classes}')
+    # logger.info(f'batch_size {batch_size}')
+    # logger.info(f'timesteps {timesteps}')
+    # logger.info(f'flatten {model_metadata.flattened}')
+    # logger.info(f'add1ch {model_metadata.channel}')
+    # logger.info(f'truth_mutex {model_metadata.mutex}')
+    # logger.info(f'input_shape {model_metadata.input_shape}')
+    # logger.info(f'output_shape {model_metadata.output_shape}')
+    # logger.info('')
+    #
+    # if splitext(entries)[1] == '.wav':
+    #     # Convert WAV to feature data
+    #     logger.info('')
+    #     logger.info(f'Run prediction on {entries}')
+    #     audio = read_audio()
+    #     feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
+    #
+    #     predict = pad_and_predict(feature=feature,
+    #                               model_name=model_name,
+    #                               model_metadata=model_metadata,
+    #                               frames_per_batch=frames_per_batch,
+    #                               batch_size=batch_size,
+    #                               timesteps=timesteps,
+    #                               reset=reset)
+    #
+    #     output_name = splitext()[0] + '.h5'
+    #     with h5py.File(output_name, 'a') as f:
+    #         if 'feature' in f:
+    #             del f['feature']
+    #         f.create_dataset(name='feature', data=feature)
+    #
+    #         if 'predict' in f:
+    #             del f['predict']
+    #         f.create_dataset(name='predict', data=predict)
+    #
+    #     logger.info(f'Saved results to {output_name}')
+    #     return
+    #
+    # if not isdir():
+    #     logger.exception(f'Do not know how to process input from {entries}')
+    #     raise SystemExit(1)
+    #
+    # mixdb = MixtureDatabase()
+    #
+    # if mixdb.feature != model_metadata.feature:
+    #     logger.exception(f'Feature in mixture database does not match feature in model')
+    #     raise SystemExit(1)
+    #
+    # mixids = mixdb.mixids_to_list(mixids)
+    # if reset:
+    #     # reset mode cycles through each file one at a time
+    #     for mixid in mixids:
+    #         feature, _ = mixdb.mixture_ft(mixid)
+    #
+    #         predict = pad_and_predict(feature=feature,
+    #                                   model_name=model_name,
+    #                                   model_metadata=model_metadata,
+    #                                   frames_per_batch=frames_per_batch,
+    #                                   batch_size=batch_size,
+    #                                   timesteps=timesteps,
+    #                                   reset=reset)
+    #
+    #         output_name = join(output_dir, mixdb.mixtures[mixid].name)
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset(name='predict', data=predict)
+    # else:
+    #     features: list[Feature] = []
+    #     file_indices: list[slice] = []
+    #     total_frames = 0
+    #     for mixid in mixids:
+    #         current_feature, _ = mixdb.mixture_ft(mixid)
+    #         current_frames = current_feature.shape[0]
+    #         features.append(current_feature)
+    #         file_indices.append(slice(total_frames, total_frames + current_frames))
+    #         total_frames += current_frames
+    #     feature = np.vstack([features[i] for i in range(len(features))])
+    #
+    #     predict = pad_and_predict(feature=feature,
+    #                               model_name=model_name,
+    #                               model_metadata=model_metadata,
+    #                               frames_per_batch=frames_per_batch,
+    #                               batch_size=batch_size,
+    #                               timesteps=timesteps,
+    #                               reset=reset)
+    #
+    #     # Write data to separate files
+    #     for idx, mixid in enumerate(mixids):
+    #         output_name = join(output_dir, mixdb.mixtures[mixid].name)
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset('predict', data=predict[file_indices[idx]])
+    #
+    # logger.info(f'Saved results to {output_dir}')
+    #
+
+    # def pad_and_predict(feature: Feature,
+    #                     model_name: str,
+    #                     model_metadata: SonusAIMetaData,
+    #                     frames_per_batch: int,
+    #                     batch_size: int,
+    #                     timesteps: int,
+    #                     reset: bool) -> Predict:
+    #     import onnxruntime as rt
+    #     import numpy as np
+    #
+    #     from sonusai.utils import reshape_inputs
+    #     from sonusai.utils import reshape_outputs
+    #
+    #     frames = feature.shape[0]
+    #     padding = frames_per_batch - frames % frames_per_batch
+    #     feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
+    #     feature, _ = reshape_inputs(feature=feature,
+    #                                 batch_size=batch_size,
+    #                                 timesteps=timesteps,
+    #                                 flatten=model_metadata.flattened,
+    #                                 add1ch=model_metadata.channel)
+    #     sequences = feature.shape[0] // model_metadata.input_shape[0]
+    #     feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
+    #
+    #     model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+    #     output_names = [n.name for n in model.get_outputs()]
+    #     input_names = [n.name for n in model.get_inputs()]
+    #
+    #     predict = []
+    #     for sequence in range(sequences):
+    #         predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
+    #         if reset:
+    #             model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+    #
+    #     predict_arr = np.vstack(predict)
+    #     # Combine [sequences, batch_size, ...] into [frames, ...]
+    #     predict_shape = predict_arr.shape
+    #     predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
+    #     predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
+    #     predict_arr = predict_arr[:frames, :]
+    #
+    #     return predict_arr


 if __name__ == '__main__':
-    main()
+    try:
+        main()
+    except KeyboardInterrupt:
+        logger.info('Canceled due to keyboard interrupt')
+        raise SystemExit(0)
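
The docstring's TSE mode describes zero-padding shorter mixtures up to the size of the largest mixture when batch_size > 1, but this release only implements the batch_size == 1 path. A hypothetical numpy sketch of that padding step (the pad_batch helper is illustrative and not part of the package):

    import numpy as np

    def pad_batch(features: list[np.ndarray]) -> np.ndarray:
        # Stack per-mixture arrays [tsteps_i, ...] into [batch, max_tsteps, ...],
        # zero-padding mixtures shorter than the longest one in the batch.
        max_tsteps = max(f.shape[0] for f in features)
        batch = np.zeros((len(features), max_tsteps) + features[0].shape[1:], dtype=features[0].dtype)
        for i, f in enumerate(features):
            batch[i, :f.shape[0]] = f
        return batch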