sonusai 0.17.0__py3-none-any.whl → 0.17.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. sonusai/audiofe.py +22 -51
  2. sonusai/calc_metric_spenh.py +206 -213
  3. sonusai/doc/doc.py +1 -1
  4. sonusai/mixture/__init__.py +2 -0
  5. sonusai/mixture/audio.py +12 -0
  6. sonusai/mixture/datatypes.py +11 -3
  7. sonusai/mixture/mixdb.py +101 -0
  8. sonusai/mixture/soundfile_audio.py +39 -0
  9. sonusai/mixture/speaker_metadata.py +35 -0
  10. sonusai/mixture/torchaudio_audio.py +22 -0
  11. sonusai/mkmanifest.py +1 -1
  12. sonusai/onnx_predict.py +114 -410
  13. sonusai/queries/queries.py +1 -1
  14. sonusai/speech/__init__.py +3 -0
  15. sonusai/speech/l2arctic.py +116 -0
  16. sonusai/speech/librispeech.py +99 -0
  17. sonusai/speech/mcgill.py +70 -0
  18. sonusai/speech/textgrid.py +100 -0
  19. sonusai/speech/timit.py +135 -0
  20. sonusai/speech/types.py +12 -0
  21. sonusai/speech/vctk.py +52 -0
  22. sonusai/speech/voxceleb2.py +86 -0
  23. sonusai/utils/__init__.py +2 -1
  24. sonusai/utils/asr_manifest_functions/__init__.py +0 -1
  25. sonusai/utils/asr_manifest_functions/data.py +0 -8
  26. sonusai/utils/asr_manifest_functions/librispeech.py +1 -1
  27. sonusai/utils/asr_manifest_functions/mcgill_speech.py +1 -1
  28. sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +1 -1
  29. sonusai/utils/braced_glob.py +7 -3
  30. sonusai/utils/onnx_utils.py +110 -106
  31. sonusai/utils/path_info.py +7 -0
  32. {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/METADATA +2 -1
  33. {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/RECORD +35 -30
  34. {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/WHEEL +1 -1
  35. sonusai/calc_metric_spenh-save.py +0 -1334
  36. sonusai/onnx_predict-old.py +0 -240
  37. sonusai/onnx_predict-save.py +0 -487
  38. sonusai/ovino_predict.py +0 -508
  39. sonusai/ovino_query_devices.py +0 -47
  40. sonusai/torchl_onnx-old.py +0 -216
  41. {sonusai-0.17.0.dist-info → sonusai-0.17.2.dist-info}/entry_points.txt +0 -0
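The largest change in this release is the rewrite of sonusai/onnx_predict.py shown below: the OpenVINO variants (ovino_predict.py, ovino_query_devices.py) and the -old/-save scratch copies are deleted, and the script's ad-hoc session loader now comes from sonusai.utils (presumably the reworked sonusai/utils/onnx_utils.py). A minimal sketch of the new loading step, assuming only what the call site in the diff shows (the 6-tuple return; the model path is a placeholder):

    # Sketch: load an ONNX model the way the rewritten onnx_predict.py does.
    # 'model.onnx' is a placeholder path; hparams comes back as None when the
    # model lacks SonusAI hyperparameter metadata.
    from sonusai.utils import load_ort_session

    session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session('model.onnx')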
sonusai/onnx_predict.py CHANGED
@@ -5,14 +5,12 @@ usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
 options:
     -h, --help
     -v, --verbose                Be verbose.
-    -l, --list-device-details    List details of all OpenVINO available devices
     -i MIXID, --mixid MIXID      Mixture ID(s) to generate if input is a mixture database. [default: *].
     --include GLOB               Search only files whose base name matches GLOB. [default: *.{wav,flac}].
     -w, --write-wav              Calculate inverse transform of prediction and write .wav files
-    -r, --reset                  Reset model between each file.
 
-Run prediction (inference) using an onnx model on a SonusAI mixture dataset or audio files from a regex path.
-The OnnxRuntime (ORT) inference engine is used to execute the inference.
+Run prediction (inference) using an ONNX model on a SonusAI mixture dataset or audio files from a glob path.
+The ONNX Runtime (ort) inference engine is used to execute the inference.
 
 Inputs:
     MODEL       ONNX model .onnx file of a trained model (weights are expected to be in the file).
@@ -33,67 +31,34 @@ Note there are multiple ways to process model prediction over multiple audio dat
    zero-padded to the size of the largest mixture.
 2. TME (timestep multi-extension): mixture is split into multiple timesteps, i.e. batch[0] is starting timesteps, ...
    Note that batches are run independently, thus sequential state from one set of timesteps to the next will not be
-   maintained, thus results for such models (i.e. conv, LSTMs, in the tstep dimension) would not match using TSE mode.
+   maintained, thus results for such models (i.e. conv, LSTMs, in the timestep dimension) would not match using
+   TSE mode.
 
 TBD not sure below make sense, need to continue ??
 2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This make sense only if
-   independent predictions are made on each frame w/o considering previous frames (i.e.
-   timesteps=1 or there is no timestep dimension in the model (timesteps=0).
-3.classification
+   independent predictions are made on each frame w/o considering previous frames (timesteps=1) or there is no
+   timestep dimension in the model (timesteps=0).
+3. Classification
 
-Outputs the following to ovpredict-<TIMESTAMP> directory:
+Outputs the following to opredict-<TIMESTAMP> directory:
     <id>.h5
         dataset: predict
     onnx_predict.log
 
 """
+import signal
 
-from sonusai import logger
-from typing import Any, List, Optional, Tuple
-from os.path import basename, splitext, exists, isfile
-import onnxruntime as ort
-import onnx
-from onnx import ValueInfoProto
-
-
-def load_ort_session(model_path, providers=['CPUExecutionProvider']):
-    if exists(model_path) and isfile(model_path):
-        model_basename = basename(model_path)
-        model_root = splitext(model_basename)[0]
-        logger.info(f'Importing model from {model_basename}')
-        try:
-            session = ort.InferenceSession(model_path, providers=providers)
-            options = ort.SessionOptions()
-        except Exception as e:
-            logger.exception(f'Error: could not load onnx model from {model_path}: {e}')
-            raise SystemExit(1)
-    else:
-        logger.exception(f'Error: model file does not exist: {model_path}')
-        raise SystemExit(1)
 
-    logger.info(f'Opened session with provider options: {session._provider_options}.')
-    try:
-        meta = session.get_modelmeta()
-        hparams = eval(meta.custom_metadata_map["hparams"])
-        logger.info(f'Sonusai hyper-parameter metadata was found in model with {len(hparams)} parameters, '
-                    f'checking for required ones ...')
-        # Print to log here will fail if required parameters not available.
-        logger.info(f'feature {hparams["feature"]}')
-        logger.info(f'batch_size {hparams["batch_size"]}')
-        logger.info(f'timesteps {hparams["timesteps"]}')
-        logger.info(f'flatten, add1ch {hparams["flatten"]}, {hparams["add1ch"]}')
-        logger.info(f'truth_mutex {hparams["truth_mutex"]}')
-    except:
-        hparams = None
-        logger.warning(f'Warning: onnx model does not have required SonusAI hyper-parameters.')
+def signal_handler(_sig, _frame):
+    import sys
+
+    from sonusai import logger
 
-    inputs = session.get_inputs()
-    outputs = session.get_outputs()
+    logger.info('Canceled due to keyboard interrupt')
+    sys.exit(1)
 
-    #in_names = [n.name for n in session.get_inputs()]
-    #out_names = [n.name for n in session.get_outputs()]
 
-    return session, options, model_root, hparams, inputs, outputs
+signal.signal(signal.SIGINT, signal_handler)
 
 
 def main() -> None:
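The TSE mode described in the docstring above maps one whole mixture into the timestep dimension with batch_size=1; the rewritten main() below does exactly that per mixture via reshape_inputs. A plain-numpy illustration of the same reshape, using hypothetical shapes (100 frames, stride 1, 40 feature parameters) rather than SonusAI's own helper:

    import numpy as np

    feature = np.zeros((100, 1, 40), dtype=np.float32)  # frames x stride x feature_params
    timesteps = feature.shape[0]                        # TSE: all frames become timesteps
    batched = feature.reshape((1, timesteps, -1))       # batch x timesteps x feature_params
    assert batched.shape == (1, 100, 40)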
@@ -105,400 +70,139 @@ def main() -> None:
     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
 
     verbose = args['--verbose']
-    listdd = args['--list-device-details']
-    writewav= args['--write-wav']
+    wav = args['--write-wav']
     mixids = args['--mixid']
-    reset = args['--reset']
     include = args['--include']
     model_path = args['MODEL']
-    datapaths = args['DATA']
+    data_paths = args['DATA']
+
+    from os import makedirs
+    from os.path import abspath
+    from os.path import basename
+    from os.path import isdir
+    from os.path import join
+    from os.path import normpath
+    from os.path import realpath
+    from os.path import splitext
+
+    import h5py
+    import numpy as np
+    import onnxruntime as ort
+
+    from sonusai import create_file_handler
+    from sonusai import initial_log_messages
+    from sonusai import logger
+    from sonusai import update_console_handler
+    from sonusai.mixture import MixtureDatabase
+    from sonusai.mixture import get_audio_from_feature
+    from sonusai.utils import PathInfo
+    from sonusai.utils import braced_iglob
+    from sonusai.utils import create_ts_name
+    from sonusai.utils import load_ort_session
+    from sonusai.utils import reshape_inputs
+    from sonusai.utils import write_wav
+
+    mixdb_path = None
+    mixdb = None
+    p_mixids = None
+    entries: list[PathInfo] = []
+
+    if len(data_paths) == 1 and isdir(data_paths[0]):
+        # Assume it's a single path to SonusAI mixdb subdir
+        in_basename = basename(normpath(data_paths[0]))
+        mixdb_path = data_paths[0]
+    else:
+        # search all data paths for .wav, .flac (or whatever is specified in include)
+        in_basename = ''
+
+    output_dir = create_ts_name('opredict-' + in_basename)
+    makedirs(output_dir, exist_ok=True)
+
+    # Setup logging file
+    create_file_handler(join(output_dir, 'onnx-predict.log'))
+    update_console_handler(verbose)
+    initial_log_messages('onnx_predict')
 
     providers = ort.get_available_providers()
-    logger.info(f'Loaded Onnx runtime, available providers: {providers}.')
+    logger.info(f'Loaded ONNX Runtime, available providers: {providers}.')
 
     session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
     if hparams is None:
-        logger.error(f'Error: onnx model does not have required Sonusai hyper-parameters, can not proceed.')
+        logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
         raise SystemExit(1)
     if len(sess_inputs) != 1:
-        logger.error(f'Error: onnx model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
+        logger.error(f'Error: ONNX model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
 
     in0name = sess_inputs[0].name
     in0type = sess_inputs[0].type
-    out0name = sess_outputs[0].name
     out_names = [n.name for n in session.get_outputs()]
 
-    from os.path import join, dirname, isdir, normpath, realpath, abspath
-    from sonusai.utils.asr_manifest_functions import PathInfo
-    from sonusai.utils import braced_iglob
-    from sonusai.mixture import MixtureDatabase
+    logger.info(f'Read and compiled ONNX model from {model_path}.')
 
-    mixdb_path = None
-    entries = None
-    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to sonusai mixdb subdir
-        in_basename = basename(normpath(datapaths[0]))
-        mixdb_path= datapaths[0]
+    if mixdb_path is not None:
+        # Assume it's a single path to SonusAI mixdb subdir
         logger.debug(f'Attempting to load mixture database from {mixdb_path}')
         mixdb = MixtureDatabase(mixdb_path)
-        logger.debug(f'Sonusai mixture db load success: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
+        logger.info(f'SonusAI mixdb: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
         p_mixids = mixdb.mixids_to_list(mixids)
         if len(p_mixids) != mixdb.num_mixtures:
             logger.info(f'Processing a subset of {p_mixids} from available mixtures.')
-    else:  # search all datapaths for .wav, .flac (or whatever is specified in include)
-        in_basename = ''
-        entries: list[PathInfo] = []
-        for p in datapaths:
+    else:
+        for p in data_paths:
             location = join(realpath(abspath(p)), '**', include)
             logger.debug(f'Processing {location}')
             for file in braced_iglob(pathname=location, recursive=True):
                 name = file
                 entries.append(PathInfo(abs_path=file, audio_filepath=name))
+        logger.info(f'{len(data_paths)} data paths specified, found {len(entries)} audio files.')
 
-    from sonusai.utils import create_ts_name
-    from os import makedirs
-    from sonusai import create_file_handler
-    from sonusai import initial_log_messages
-    from sonusai import update_console_handler
-    output_dir = create_ts_name('opredict-' + in_basename)
-    makedirs(output_dir, exist_ok=True)
-    # Setup logging file
-    create_file_handler(join(output_dir, 'onnx-predict.log'))
-    update_console_handler(verbose)
-    initial_log_messages('onnx_predict')
-    # Reprint some info messages
-    logger.info(f'Loaded OnnxRuntime, available providers: {providers}.')
-    logger.info(f'Read and compiled onnx model from {model_path}.')
-    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to sonusai mixdb subdir
-        logger.info(f'Loaded mixture database from {datapaths}')
-        logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
-    else:
-        logger.info(f'{len(datapaths)} data paths specified, found {len(entries)} audio files.')
     if in0type.find('float16') != -1:
         model_is_fp16 = True
         logger.info(f'Detected input of float16, converting all feature inputs to that type.')
     else:
         model_is_fp16 = False
 
-
-    if mixdb_path is not None:  # mixdb input
+    if mixdb_path is not None and hparams['batch_size'] == 1:
+        # mixdb input
         # Assume (of course) that mixdb feature, etc. is what model expects
-        if hparams["feature"] != mixdb.feature:
+        if hparams['feature'] != mixdb.feature:
             logger.warning(f'Mixture feature does not match model feature, this inference run may fail.')
-        feature_mode = mixdb.feature  # no choice, can't use hparams["feature"] since it's different than the mixdb
-
-        #if hparams["num_classes"] != mixdb.num_classes:  # needs to be i.e. mixdb.feature_parameters
-        #if mixdb.num_classes != model_num_classes:
-        #    logger.error(f'Feature parameters in mixture db {mixdb.num_classes} does not match num_classes in model {inp0shape[-1]}')
-        #    raise SystemExit(1)
-
-        from sonusai.utils import reshape_inputs
-        from sonusai.utils import reshape_outputs
-        from sonusai.mixture import get_audio_from_feature
-        from sonusai.utils import write_wav
-        import numpy as np
-        import h5py
-        if hparams["batch_size"] == 1:
-            for mixid in p_mixids:
-                feature, _ = mixdb.mixture_ft(mixid)  # frames x stride x feature_params
-                if hparams["timesteps"] == 0:
-                    tsteps = 0  # no timestep dim, reshape will take care
-                else:
-                    tsteps = feature.shape[0]  # fit frames into timestep dimension (TSE mode)
-                feature, _ = reshape_inputs(feature=feature,
-                                            batch_size=1,
-                                            timesteps=tsteps,
-                                            flatten=hparams["flatten"],
-                                            add1ch=hparams["add1ch"])
-                if model_is_fp16:
-                    feature = np.float16(feature)
-                # run inference, ort session wants i.e. batch x tsteps x feat_params, outputs numpy BxTxFP or BxFP
-                predict = session.run(out_names, {in0name: feature})[0]
-                #predict, _ = reshape_outputs(predict=predict[0], timesteps=frames)  # frames x feat_params
-                output_fname = join(output_dir, mixdb.mixtures[mixid].name)
-                with h5py.File(output_fname, 'a') as f:
-                    if 'predict' in f:
-                        del f['predict']
-                    f.create_dataset('predict', data=predict)
-                if writewav:  # note only makes sense if model is predicting audio, i.e. tstep dimension exists
-                    # predict_audio wants [frames, channels, feature_parameters] equiv. to tsteps, batch, bins
-                    predict = np.transpose(predict, [1, 0, 2])
-                    predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
-                    owav_name = splitext(output_fname)[0]+'_predict.wav'
-                    write_wav(owav_name, predict_audio)
-
-
-#
-# # sampler = None
-# # p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
-# #                                      mixids=p_mixids,
-# #                                      batch_size=batch_size,
-# #                                      cut_len=0,
-# #                                      flatten=model.flatten,
-# #                                      add1ch=model.add1ch,
-# #                                      random_cut=False,
-# #                                      sampler=sampler,
-# #                                      drop_last=False,
-# #                                      num_workers=dlcpu)
-#
-# # Info needed to set up inverse transform
-# half = model.num_classes // 2
-# fg = FeatureGenerator(feature_mode=feature,
-#                       num_classes=model.num_classes,
-#                       truth_mutex=model.truth_mutex)
-# itf = TorchInverseTransform(N=fg.itransform_N,
-#                             R=fg.itransform_R,
-#                             bin_start=fg.bin_start,
-#                             bin_end=fg.bin_end,
-#                             ttype=fg.itransform_ttype)
-#
-# enable_truth_wav = False
-# enable_mix_wav = False
-# if wavdbg:
-#     if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
-#         enable_mix_wav = True
-#         enable_truth_wav = True
-#     elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
-#         enable_truth_wav = True
-#
-# if reset:
-#     logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
-#     for idx, val in enumerate(p_datagen):
-#         # truth = val[1]
-#         feature = val[0]
-#         with torch.no_grad():
-#             ypred = model(feature)
-#         output_name = join(output_dir, mixdb.mixtures[idx].name)
-#         pdat = ypred.detach().numpy()
-#         if timesteps > 0:
-#             logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
-#         logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
-#         with h5py.File(output_name, 'a') as f:
-#             if 'predict' in f:
-#                 del f['predict']
-#             f.create_dataset('predict', data=pdat)
-#
-#         if wavdbg:
-#             owav_base = splitext(output_name)[0]
-#             tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
-#             itf.reset()
-#             predwav, _ = itf.execute_all(tmp)
-#             # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
-#             write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
-#             if enable_truth_wav:
-#                 # Note this support truth type target_f and target_mixture_f
-#                 tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
-#                 itf.reset()
-#                 truthwav, _ = itf.execute_all(tmp)
-#                 write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
-#
-#             if enable_mix_wav:
-#                 tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
-#                 itf.reset()
-#                 mixwav, _ = itf.execute_all(tmp.detach())
-#                 write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-#
-# from os import makedirs
-# from os.path import isdir
-# from os.path import join
-# from os.path import splitext
-#
-# import h5py
-# import onnxruntime as rt
-# import numpy as np
-#
-# from sonusai.mixture import Feature
-# from sonusai.mixture import Predict
-# from sonusai.utils import SonusAIMetaData
-# from sonusai import create_file_handler
-# from sonusai import initial_log_messages
-# from sonusai import update_console_handler
-# from sonusai.mixture import MixtureDatabase
-# from sonusai.mixture import get_feature_from_audio
-# from sonusai.mixture import read_audio
-# from sonusai.utils import create_ts_name
-# from sonusai.utils import get_frames_per_batch
-# from sonusai.utils import get_sonusai_metadata
-#
-# output_dir = create_ts_name('ovpredict')
-# makedirs(output_dir, exist_ok=True)
-#
-#
-#
-# model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-# model_metadata = get_sonusai_metadata(model)
-#
-# batch_size = model_metadata.input_shape[0]
-# if model_metadata.timestep:
-#     timesteps = model_metadata.input_shape[1]
-# else:
-#     timesteps = 0
-# num_classes = model_metadata.output_shape[-1]
-#
-# frames_per_batch = get_frames_per_batch(batch_size, timesteps)
-#
-# logger.info('')
-# logger.info(f'feature {model_metadata.feature}')
-# logger.info(f'num_classes {num_classes}')
-# logger.info(f'batch_size {batch_size}')
-# logger.info(f'timesteps {timesteps}')
-# logger.info(f'flatten {model_metadata.flattened}')
-# logger.info(f'add1ch {model_metadata.channel}')
-# logger.info(f'truth_mutex {model_metadata.mutex}')
-# logger.info(f'input_shape {model_metadata.input_shape}')
-# logger.info(f'output_shape {model_metadata.output_shape}')
-# logger.info('')
-#
-# if splitext(entries)[1] == '.wav':
-#     # Convert WAV to feature data
-#     logger.info('')
-#     logger.info(f'Run prediction on {entries}')
-#     audio = read_audio()
-#     feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
-#
-#     predict = pad_and_predict(feature=feature,
-#                               model_name=model_name,
-#                               model_metadata=model_metadata,
-#                               frames_per_batch=frames_per_batch,
-#                               batch_size=batch_size,
-#                               timesteps=timesteps,
-#                               reset=reset)
-#
-#     output_name = splitext()[0] + '.h5'
-#     with h5py.File(output_name, 'a') as f:
-#         if 'feature' in f:
-#             del f['feature']
-#         f.create_dataset(name='feature', data=feature)
-#
-#         if 'predict' in f:
-#             del f['predict']
-#         f.create_dataset(name='predict', data=predict)
-#
-#     logger.info(f'Saved results to {output_name}')
-#     return
-#
-# if not isdir():
-#     logger.exception(f'Do not know how to process input from {entries}')
-#     raise SystemExit(1)
-#
-# mixdb = MixtureDatabase()
-#
-# if mixdb.feature != model_metadata.feature:
-#     logger.exception(f'Feature in mixture database does not match feature in model')
-#     raise SystemExit(1)
-#
-# mixids = mixdb.mixids_to_list(mixids)
-# if reset:
-#     # reset mode cycles through each file one at a time
-#     for mixid in mixids:
-#         feature, _ = mixdb.mixture_ft(mixid)
-#
-#         predict = pad_and_predict(feature=feature,
-#                                   model_name=model_name,
-#                                   model_metadata=model_metadata,
-#                                   frames_per_batch=frames_per_batch,
-#                                   batch_size=batch_size,
-#                                   timesteps=timesteps,
-#                                   reset=reset)
-#
-#         output_name = join(output_dir, mixdb.mixtures[mixid].name)
-#         with h5py.File(output_name, 'a') as f:
-#             if 'predict' in f:
-#                 del f['predict']
-#             f.create_dataset(name='predict', data=predict)
-# else:
-#     features: list[Feature] = []
-#     file_indices: list[slice] = []
-#     total_frames = 0
-#     for mixid in mixids:
-#         current_feature, _ = mixdb.mixture_ft(mixid)
-#         current_frames = current_feature.shape[0]
-#         features.append(current_feature)
-#         file_indices.append(slice(total_frames, total_frames + current_frames))
-#         total_frames += current_frames
-#     feature = np.vstack([features[i] for i in range(len(features))])
-#
-#     predict = pad_and_predict(feature=feature,
-#                               model_name=model_name,
-#                               model_metadata=model_metadata,
-#                               frames_per_batch=frames_per_batch,
-#                               batch_size=batch_size,
-#                               timesteps=timesteps,
-#                               reset=reset)
-#
-#     # Write data to separate files
-#     for idx, mixid in enumerate(mixids):
-#         output_name = join(output_dir, mixdb.mixtures[mixid].name)
-#         with h5py.File(output_name, 'a') as f:
-#             if 'predict' in f:
-#                 del f['predict']
-#             f.create_dataset('predict', data=predict[file_indices[idx]])
-#
-# logger.info(f'Saved results to {output_dir}')
-#
-
-# def pad_and_predict(feature: Feature,
-#                     model_name: str,
-#                     model_metadata: SonusAIMetaData,
-#                     frames_per_batch: int,
-#                     batch_size: int,
-#                     timesteps: int,
-#                     reset: bool) -> Predict:
-#     import onnxruntime as rt
-#     import numpy as np
-#
-#     from sonusai.utils import reshape_inputs
-#     from sonusai.utils import reshape_outputs
-#
-#     frames = feature.shape[0]
-#     padding = frames_per_batch - frames % frames_per_batch
-#     feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
-#     feature, _ = reshape_inputs(feature=feature,
-#                                 batch_size=batch_size,
-#                                 timesteps=timesteps,
-#                                 flatten=model_metadata.flattened,
-#                                 add1ch=model_metadata.channel)
-#     sequences = feature.shape[0] // model_metadata.input_shape[0]
-#     feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
-#
-#     model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-#     output_names = [n.name for n in model.get_outputs()]
-#     input_names = [n.name for n in model.get_inputs()]
-#
-#     predict = []
-#     for sequence in range(sequences):
-#         predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
-#         if reset:
-#             model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
-#
-#     predict_arr = np.vstack(predict)
-#     # Combine [sequences, batch_size, ...] into [frames, ...]
-#     predict_shape = predict_arr.shape
-#     predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
-#     predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
-#     predict_arr = predict_arr[:frames, :]
-#
-#     return predict_arr
+        # no choice, can't use hparams.feature since it's different from the mixdb
+        feature_mode = mixdb.feature
+
+        for mixid in p_mixids:
+            # frames x stride x feature_params
+            feature, _ = mixdb.mixture_ft(mixid)
+            if hparams['timesteps'] == 0:
+                # no timestep dimension, reshape will handle
+                timesteps = 0
+            else:
+                # fit frames into timestep dimension (TSE mode)
+                timesteps = feature.shape[0]
+
+            feature, _ = reshape_inputs(feature=feature,
+                                        batch_size=1,
+                                        timesteps=timesteps,
+                                        flatten=hparams['flatten'],
+                                        add1ch=hparams['add1ch'])
+            if model_is_fp16:
+                feature = np.float16(feature)  # type: ignore
+            # run inference, ort session wants i.e. batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
+            predict = session.run(out_names, {in0name: feature})[0]
+            # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames)  # frames x feat_params
+            output_fname = join(output_dir, mixdb.mixtures[mixid].name)
+            with h5py.File(output_fname, 'a') as f:
+                if 'predict' in f:
+                    del f['predict']
+                f.create_dataset('predict', data=predict)
+            if wav:
+                # note only makes sense if model is predicting audio, i.e., timestep dimension exists
+                # predict_audio wants [frames, channels, feature_parameters] equivalent to timesteps, batch, bins
+                predict = np.transpose(predict, [1, 0, 2])
+                predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
+                owav_name = splitext(output_fname)[0] + '_predict.wav'
+                write_wav(owav_name, predict_audio)
 
 
 if __name__ == '__main__':
-    try:
-        main()
-    except KeyboardInterrupt:
-        logger.info('Canceled due to keyboard interrupt')
-        raise SystemExit(0)
+    main()
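Each mixture's prediction is written to a 'predict' dataset in an HDF5 file under the opredict-<TIMESTAMP> directory. A hedged sketch of reading one result back (only the dataset name comes from the code above; the file path is a placeholder):

    import h5py

    # Placeholder path: substitute the actual timestamped directory and <id>.h5 name.
    with h5py.File('opredict-<TIMESTAMP>/<id>.h5', 'r') as f:
        predict = f['predict'][:]
    print(predict.shape)  # batch x timesteps x feature_params, or batch x feature_params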
sonusai/queries/queries.py CHANGED
@@ -134,7 +134,7 @@ def get_mixids_from_target(mixdb: MixtureDatabase,
     """
     return get_mixids_from_mixture_field_predicate(mixdb=mixdb,
                                                    mixids=mixids,
-                                                   field='target_id',
+                                                   field='target_ids',
                                                    predicate=predicate)
 
 
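This one-character rename points get_mixids_from_target at the correct mixture field ('target_ids'). A hypothetical call, assuming the predicate is applied to each target id value (only the keyword names come from the signature above; the import path and mixdb location are placeholders):

    from sonusai.mixture import MixtureDatabase
    from sonusai.queries import get_mixids_from_target

    mixdb = MixtureDatabase('path/to/mixdb')  # placeholder location
    result = get_mixids_from_target(mixdb=mixdb, mixids='*', predicate=lambda target_id: target_id == 1)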
sonusai/speech/__init__.py CHANGED
@@ -0,0 +1,3 @@
+from .textgrid import annotate_textgrid
+from .textgrid import create_textgrid
+from .types import TimeAlignedType
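The new sonusai.speech subpackage (l2arctic, librispeech, mcgill, textgrid, timit, vctk, and voxceleb2 modules in the file list above) exposes exactly three public names from its __init__. A minimal import smoke test, using only the names shown in this hunk:

    from sonusai.speech import TimeAlignedType
    from sonusai.speech import annotate_textgrid
    from sonusai.speech import create_textgrid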