sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +157 -61
- sonusai/calc_metric_spenh-save.py +1334 -0
- sonusai/calc_metric_spenh.py +15 -8
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +14 -6
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/mkmanifest.py +14 -6
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict-old.py +240 -0
- sonusai/onnx_predict-save.py +487 -0
- sonusai/onnx_predict.py +446 -182
- sonusai/ovino_predict.py +508 -0
- sonusai/ovino_query_devices.py +47 -0
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/torchl_onnx-old.py +216 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/onnx_utils.py +128 -39
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/METADATA +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/RECORD +26 -19
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/WHEEL +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/entry_points.txt +0 -0
sonusai/onnx_predict.py
CHANGED
@@ -1,19 +1,23 @@
-"""sonusai
+"""sonusai onnx_predict
 
-usage:
+usage: onnx_predict [-hvlwr] [--include GLOB] [-i MIXID] MODEL DATA ...
 
 options:
     -h, --help
     -v, --verbose                   Be verbose.
+    -l, --list-device-details      List details of all OpenVINO available devices
     -i MIXID, --mixid MIXID        Mixture ID(s) to generate if input is a mixture database. [default: *].
-
+    --include GLOB                 Search only files whose base name matches GLOB. [default: *.{wav,flac}].
+    -w, --write-wav                Calculate inverse transform of prediction and write .wav files
     -r, --reset                    Reset model between each file.
 
-Run prediction
+Run prediction (inference) using an onnx model on a SonusAI mixture dataset or audio files from a regex path.
+The OnnxRuntime (ORT) inference engine is used to execute the inference.
 
 Inputs:
-    MODEL
-
+    MODEL       ONNX model .onnx file of a trained model (weights are expected to be in the file).
+
+    DATA        The input data must be one of the following:
     * WAV
       Using the given model, generate feature data and run prediction. A model file must be
       provided. The MIXID is ignored.
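As a point of reference, a typical invocation of the usage line above might be (paths are hypothetical):

    onnx_predict -v -w trained_model.onnx /data/my_mixdb

Here MODEL and DATA are the new positional arguments replacing the removed --model option, and -w additionally writes .wav files reconstructed from the predictions.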
@@ -22,7 +26,22 @@ Inputs:
       Using the given SonusAI mixture database directory, generate feature and truth data if not found.
       Run prediction. The MIXID is required.
 
-
+
+Note there are multiple ways to process model prediction over multiple audio data files:
+1. TSE (timestep single extension): mixture transform frames are fit into the timestep dimension and the model run as
+   a single inference call. If batch_size is > 1 then run multiple mixtures in one call with shorter mixtures
+   zero-padded to the size of the largest mixture.
+2. TME (timestep multi-extension): mixture is split into multiple timesteps, i.e. batch[0] is starting timesteps, ...
+   Note that batches are run independently, thus sequential state from one set of timesteps to the next will not be
+   maintained, thus results for such models (i.e. conv, LSTMs, in the tstep dimension) would not match using TSE mode.
+
+TBD not sure below make sense, need to continue ??
+2. BSE (batch single extension): mixture transform frames are fit into the batch dimension. This make sense only if
+   independent predictions are made on each frame w/o considering previous frames (i.e.
+   timesteps=1 or there is no timestep dimension in the model (timesteps=0).
+3. classification
+
+Outputs the following to ovpredict-<TIMESTAMP> directory:
     <id>.h5
         dataset: predict
     onnx_predict.log
@@ -30,9 +49,51 @@ Outputs the following to opredict-<TIMESTAMP> directory:
 """
 
 from sonusai import logger
-from
-from
-
+from typing import Any, List, Optional, Tuple
+from os.path import basename, splitext, exists, isfile
+import onnxruntime as ort
+import onnx
+from onnx import ValueInfoProto
+
+
+def load_ort_session(model_path, providers=['CPUExecutionProvider']):
+    if exists(model_path) and isfile(model_path):
+        model_basename = basename(model_path)
+        model_root = splitext(model_basename)[0]
+        logger.info(f'Importing model from {model_basename}')
+        try:
+            session = ort.InferenceSession(model_path, providers=providers)
+            options = ort.SessionOptions()
+        except Exception as e:
+            logger.exception(f'Error: could not load onnx model from {model_path}: {e}')
+            raise SystemExit(1)
+    else:
+        logger.exception(f'Error: model file does not exist: {model_path}')
+        raise SystemExit(1)
+
+    logger.info(f'Opened session with provider options: {session._provider_options}.')
+    try:
+        meta = session.get_modelmeta()
+        hparams = eval(meta.custom_metadata_map["hparams"])
+        logger.info(f'Sonusai hyper-parameter metadata was found in model with {len(hparams)} parameters, '
+                    f'checking for required ones ...')
+        # Print to log here will fail if required parameters not available.
+        logger.info(f'feature {hparams["feature"]}')
+        logger.info(f'batch_size {hparams["batch_size"]}')
+        logger.info(f'timesteps {hparams["timesteps"]}')
+        logger.info(f'flatten, add1ch {hparams["flatten"]}, {hparams["add1ch"]}')
+        logger.info(f'truth_mutex {hparams["truth_mutex"]}')
+    except:
+        hparams = None
+        logger.warning(f'Warning: onnx model does not have required SonusAI hyper-parameters.')
+
+    inputs = session.get_inputs()
+    outputs = session.get_outputs()
+
+    # in_names = [n.name for n in session.get_inputs()]
+    # out_names = [n.name for n in session.get_outputs()]
+
+    return session, options, model_root, hparams, inputs, outputs
 
 
 def main() -> None:
@@ -44,192 +105,395 @@ def main() -> None:
     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
 
     verbose = args['--verbose']
+    listdd = args['--list-device-details']
+    writewav = args['--write-wav']
     mixids = args['--mixid']
-    model_name = args['--model']
     reset = args['--reset']
-
+    include = args['--include']
+    model_path = args['MODEL']
+    datapaths = args['DATA']
 
-
-
-
-
+    providers = ort.get_available_providers()
+    logger.info(f'Loaded Onnx runtime, available providers: {providers}.')
+
+    session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_path)
+    if hparams is None:
+        logger.error(f'Error: onnx model does not have required Sonusai hyper-parameters, can not proceed.')
+        raise SystemExit(1)
+    if len(sess_inputs) != 1:
+        logger.error(f'Error: onnx model does not have 1 input, but {len(sess_inputs)}. Exit due to unknown input.')
+
+    in0name = sess_inputs[0].name
+    in0type = sess_inputs[0].type
+    out0name = sess_outputs[0].name
+    out_names = [n.name for n in session.get_outputs()]
 
-    import
-
-
+    from os.path import join, dirname, isdir, normpath, realpath, abspath
+    from sonusai.utils.asr_manifest_functions import PathInfo
+    from sonusai.utils import braced_iglob
+    from sonusai.mixture import MixtureDatabase
+
+    mixdb_path = None
+    entries = None
+    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to sonusai mixdb subdir
+        in_basename = basename(normpath(datapaths[0]))
+        mixdb_path = datapaths[0]
+        logger.debug(f'Attempting to load mixture database from {mixdb_path}')
+        mixdb = MixtureDatabase(mixdb_path)
+        logger.debug(f'Sonusai mixture db load success: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
+        p_mixids = mixdb.mixids_to_list(mixids)
+        if len(p_mixids) != mixdb.num_mixtures:
+            logger.info(f'Processing a subset of {p_mixids} from available mixtures.')
+    else:  # search all datapaths for .wav, .flac (or whatever is specified in include)
+        in_basename = ''
+        entries: list[PathInfo] = []
+        for p in datapaths:
+            location = join(realpath(abspath(p)), '**', include)
+            logger.debug(f'Processing {location}')
+            for file in braced_iglob(pathname=location, recursive=True):
+                name = file
+                entries.append(PathInfo(abs_path=file, audio_filepath=name))
 
+    from sonusai.utils import create_ts_name
+    from os import makedirs
     from sonusai import create_file_handler
     from sonusai import initial_log_messages
     from sonusai import update_console_handler
-
-    from sonusai.mixture import get_feature_from_audio
-    from sonusai.mixture import read_audio
-    from sonusai.utils import create_ts_name
-    from sonusai.utils import get_frames_per_batch
-    from sonusai.utils import get_sonusai_metadata
-
-    output_dir = create_ts_name('opredict')
+    output_dir = create_ts_name('opredict-' + in_basename)
     makedirs(output_dir, exist_ok=True)
-
     # Setup logging file
-    create_file_handler(join(output_dir, '
+    create_file_handler(join(output_dir, 'onnx-predict.log'))
     update_console_handler(verbose)
     initial_log_messages('onnx_predict')
-
-
-
-
-
-
-        timesteps = model_metadata.input_shape[1]
+    # Reprint some info messages
+    logger.info(f'Loaded OnnxRuntime, available providers: {providers}.')
+    logger.info(f'Read and compiled onnx model from {model_path}.')
+    if len(datapaths) == 1 and isdir(datapaths[0]):  # Assume it's a single path to sonusai mixdb subdir
+        logger.info(f'Loaded mixture database from {datapaths}')
+        logger.info(f'Sonusai mixture db: found {mixdb.num_mixtures} mixtures with {mixdb.num_classes} classes')
     else:
-
-
-
-
-
-    logger.info('')
-    logger.info(f'feature {model_metadata.feature}')
-    logger.info(f'num_classes {num_classes}')
-    logger.info(f'batch_size {batch_size}')
-    logger.info(f'timesteps {timesteps}')
-    logger.info(f'flatten {model_metadata.flattened}')
-    logger.info(f'add1ch {model_metadata.channel}')
-    logger.info(f'truth_mutex {model_metadata.mutex}')
-    logger.info(f'input_shape {model_metadata.input_shape}')
-    logger.info(f'output_shape {model_metadata.output_shape}')
-    logger.info('')
-
-    if splitext(input_name)[1] == '.wav':
-        # Convert WAV to feature data
-        logger.info('')
-        logger.info(f'Run prediction on {input_name}')
-        audio = read_audio(input_name)
-        feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
-
-        predict = pad_and_predict(feature=feature,
-                                  model_name=model_name,
-                                  model_metadata=model_metadata,
-                                  frames_per_batch=frames_per_batch,
-                                  batch_size=batch_size,
-                                  timesteps=timesteps,
-                                  reset=reset)
-
-        output_name = splitext(input_name)[0] + '.h5'
-        with h5py.File(output_name, 'a') as f:
-            if 'feature' in f:
-                del f['feature']
-            f.create_dataset(name='feature', data=feature)
-
-            if 'predict' in f:
-                del f['predict']
-            f.create_dataset(name='predict', data=predict)
-
-        logger.info(f'Saved results to {output_name}')
-        return
-
-    if not isdir(input_name):
-        logger.exception(f'Do not know how to process input from {input_name}')
-        raise SystemExit(1)
-
-    mixdb = MixtureDatabase(input_name)
-
-    if mixdb.feature != model_metadata.feature:
-        logger.exception(f'Feature in mixture database does not match feature in model')
-        raise SystemExit(1)
-
-    mixids = mixdb.mixids_to_list(mixids)
-    if reset:
-        # reset mode cycles through each file one at a time
-        for mixid in mixids:
-            feature, _ = mixdb.mixture_ft(mixid)
-
-            predict = pad_and_predict(feature=feature,
-                                      model_name=model_name,
-                                      model_metadata=model_metadata,
-                                      frames_per_batch=frames_per_batch,
-                                      batch_size=batch_size,
-                                      timesteps=timesteps,
-                                      reset=reset)
-
-            output_name = join(output_dir, mixdb.mixtures[mixid].name)
-            with h5py.File(output_name, 'a') as f:
-                if 'predict' in f:
-                    del f['predict']
-                f.create_dataset(name='predict', data=predict)
+        logger.info(f'{len(datapaths)} data paths specified, found {len(entries)} audio files.')
+    if in0type.find('float16') != -1:
+        model_is_fp16 = True
+        logger.info(f'Detected input of float16, converting all feature inputs to that type.')
     else:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    #
-
-
-
-
-
-
+        model_is_fp16 = False
+
+
+    if mixdb_path is not None:  # mixdb input
+        # Assume (of course) that mixdb feature, etc. is what model expects
+        if hparams["feature"] != mixdb.feature:
+            logger.warning(f'Mixture feature does not match model feature, this inference run may fail.')
+        feature_mode = mixdb.feature  # no choice, can't use hparams["feature"] since it's different than the mixdb
+
+        # if hparams["num_classes"] != mixdb.num_classes:  # needs to be i.e. mixdb.feature_parameters
+        # if mixdb.num_classes != model_num_classes:
+        #     logger.error(f'Feature parameters in mixture db {mixdb.num_classes} does not match num_classes in model {inp0shape[-1]}')
+        #     raise SystemExit(1)
+
+        from sonusai.utils import reshape_inputs
+        from sonusai.utils import reshape_outputs
+        from sonusai.mixture import get_audio_from_feature
+        from sonusai.utils import write_wav
+        import numpy as np
+        import h5py
+        if hparams["batch_size"] == 1:
+            for mixid in p_mixids:
+                feature, _ = mixdb.mixture_ft(mixid)  # frames x stride x feature_params
+                if hparams["timesteps"] == 0:
+                    tsteps = 0  # no timestep dim, reshape will take care
+                else:
+                    tsteps = feature.shape[0]  # fit frames into timestep dimension (TSE mode)
+                feature, _ = reshape_inputs(feature=feature,
+                                            batch_size=1,
+                                            timesteps=tsteps,
+                                            flatten=hparams["flatten"],
+                                            add1ch=hparams["add1ch"])
+                if model_is_fp16:
+                    feature = np.float16(feature)
+                # run inference, ort session wants i.e. batch x tsteps x feat_params, outputs numpy BxTxFP or BxFP
+                predict = session.run(out_names, {in0name: feature})[0]
+                # predict, _ = reshape_outputs(predict=predict[0], timesteps=frames)  # frames x feat_params
+                output_fname = join(output_dir, mixdb.mixtures[mixid].name)
+                with h5py.File(output_fname, 'a') as f:
+                    if 'predict' in f:
+                        del f['predict']
+                    f.create_dataset('predict', data=predict)
+                if writewav:  # note only makes sense if model is predicting audio, i.e. tstep dimension exists
+                    # predict_audio wants [frames, channels, feature_parameters] equiv. to tsteps, batch, bins
+                    predict = np.transpose(predict, [1, 0, 2])
+                    predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
+                    owav_name = splitext(output_fname)[0] + '_predict.wav'
+                    write_wav(owav_name, predict_audio)
+
+
+    #
+    # # sampler = None
+    # # p_datagen = TorchFromMixtureDatabase(mixdb=mixdb,
+    # #                                      mixids=p_mixids,
+    # #                                      batch_size=batch_size,
+    # #                                      cut_len=0,
+    # #                                      flatten=model.flatten,
+    # #                                      add1ch=model.add1ch,
+    # #                                      random_cut=False,
+    # #                                      sampler=sampler,
+    # #                                      drop_last=False,
+    # #                                      num_workers=dlcpu)
+    #
+    # # Info needed to set up inverse transform
+    # half = model.num_classes // 2
+    # fg = FeatureGenerator(feature_mode=feature,
+    #                       num_classes=model.num_classes,
+    #                       truth_mutex=model.truth_mutex)
+    # itf = TorchInverseTransform(N=fg.itransform_N,
+    #                             R=fg.itransform_R,
+    #                             bin_start=fg.bin_start,
+    #                             bin_end=fg.bin_end,
+    #                             ttype=fg.itransform_ttype)
+    #
+    # enable_truth_wav = False
+    # enable_mix_wav = False
+    # if wavdbg:
+    #     if mixdb.target_files[0].truth_settings[0].function == 'target_mixture_f':
+    #         enable_mix_wav = True
+    #         enable_truth_wav = True
+    #     elif mixdb.target_files[0].truth_settings[0].function == 'target_f':
+    #         enable_truth_wav = True
+    #
+    # if reset:
+    #     logger.info(f'Running {mixdb.num_mixtures} mixtures individually with model reset ...')
+    #     for idx, val in enumerate(p_datagen):
+    #         # truth = val[1]
+    #         feature = val[0]
+    #         with torch.no_grad():
+    #             ypred = model(feature)
+    #         output_name = join(output_dir, mixdb.mixtures[idx].name)
+    #         pdat = ypred.detach().numpy()
+    #         if timesteps > 0:
+    #             logger.debug(f'In and out tsteps: {feature.shape[1]},{pdat.shape[1]}')
+    #         logger.debug(f'Writing predict shape {pdat.shape} to {output_name}')
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset('predict', data=pdat)
+    #
+    #         if wavdbg:
+    #             owav_base = splitext(output_name)[0]
+    #             tmp = torch.complex(ypred[..., :half], ypred[..., half:]).permute(2, 0, 1).detach()
+    #             itf.reset()
+    #             predwav, _ = itf.execute_all(tmp)
+    #             # predwav, _ = calculate_audio_from_transform(tmp.numpy(), itf, trim=True)
+    #             write_wav(owav_base + '.wav', predwav.permute([1, 0]).numpy(), 16000)
+    #             if enable_truth_wav:
+    #                 # Note this support truth type target_f and target_mixture_f
+    #                 tmp = torch.complex(val[0][..., :half], val[0][..., half:2 * half]).permute(2, 0, 1).detach()
+    #                 itf.reset()
+    #                 truthwav, _ = itf.execute_all(tmp)
+    #                 write_wav(owav_base + '_truth.wav', truthwav.permute([1, 0]).numpy(), 16000)
+    #
+    #             if enable_mix_wav:
+    #                 tmp = torch.complex(val[0][..., 2 * half:3 * half], val[0][..., 3 * half:]).permute(2, 0, 1)
+    #                 itf.reset()
+    #                 mixwav, _ = itf.execute_all(tmp.detach())
+    #                 write_wav(owav_base + '_mix.wav', mixwav.permute([1, 0]).numpy(), 16000)
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    #
+    # from os import makedirs
+    # from os.path import isdir
+    # from os.path import join
+    # from os.path import splitext
+    #
+    # import h5py
+    # import onnxruntime as rt
+    # import numpy as np
+    #
+    # from sonusai.mixture import Feature
+    # from sonusai.mixture import Predict
+    # from sonusai.utils import SonusAIMetaData
+    # from sonusai import create_file_handler
+    # from sonusai import initial_log_messages
+    # from sonusai import update_console_handler
+    # from sonusai.mixture import MixtureDatabase
+    # from sonusai.mixture import get_feature_from_audio
+    # from sonusai.mixture import read_audio
+    # from sonusai.utils import create_ts_name
+    # from sonusai.utils import get_frames_per_batch
+    # from sonusai.utils import get_sonusai_metadata
+    #
+    # output_dir = create_ts_name('ovpredict')
+    # makedirs(output_dir, exist_ok=True)
+    #
+    #
+    #
+    # model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+    # model_metadata = get_sonusai_metadata(model)
+    #
+    # batch_size = model_metadata.input_shape[0]
+    # if model_metadata.timestep:
+    #     timesteps = model_metadata.input_shape[1]
+    # else:
+    #     timesteps = 0
+    # num_classes = model_metadata.output_shape[-1]
+    #
+    # frames_per_batch = get_frames_per_batch(batch_size, timesteps)
+    #
+    # logger.info('')
+    # logger.info(f'feature {model_metadata.feature}')
+    # logger.info(f'num_classes {num_classes}')
+    # logger.info(f'batch_size {batch_size}')
+    # logger.info(f'timesteps {timesteps}')
+    # logger.info(f'flatten {model_metadata.flattened}')
+    # logger.info(f'add1ch {model_metadata.channel}')
+    # logger.info(f'truth_mutex {model_metadata.mutex}')
+    # logger.info(f'input_shape {model_metadata.input_shape}')
+    # logger.info(f'output_shape {model_metadata.output_shape}')
+    # logger.info('')
+    #
+    # if splitext(entries)[1] == '.wav':
+    #     # Convert WAV to feature data
+    #     logger.info('')
+    #     logger.info(f'Run prediction on {entries}')
+    #     audio = read_audio()
+    #     feature = get_feature_from_audio(audio=audio, feature_mode=model_metadata.feature)
+    #
+    #     predict = pad_and_predict(feature=feature,
+    #                               model_name=model_name,
+    #                               model_metadata=model_metadata,
+    #                               frames_per_batch=frames_per_batch,
+    #                               batch_size=batch_size,
+    #                               timesteps=timesteps,
+    #                               reset=reset)
+    #
+    #     output_name = splitext()[0] + '.h5'
+    #     with h5py.File(output_name, 'a') as f:
+    #         if 'feature' in f:
+    #             del f['feature']
+    #         f.create_dataset(name='feature', data=feature)
+    #
+    #         if 'predict' in f:
+    #             del f['predict']
+    #         f.create_dataset(name='predict', data=predict)
+    #
+    #     logger.info(f'Saved results to {output_name}')
+    #     return
+    #
+    # if not isdir():
+    #     logger.exception(f'Do not know how to process input from {entries}')
+    #     raise SystemExit(1)
+    #
+    # mixdb = MixtureDatabase()
+    #
+    # if mixdb.feature != model_metadata.feature:
+    #     logger.exception(f'Feature in mixture database does not match feature in model')
+    #     raise SystemExit(1)
+    #
+    # mixids = mixdb.mixids_to_list(mixids)
+    # if reset:
+    #     # reset mode cycles through each file one at a time
+    #     for mixid in mixids:
+    #         feature, _ = mixdb.mixture_ft(mixid)
+    #
+    #         predict = pad_and_predict(feature=feature,
+    #                                   model_name=model_name,
+    #                                   model_metadata=model_metadata,
+    #                                   frames_per_batch=frames_per_batch,
+    #                                   batch_size=batch_size,
+    #                                   timesteps=timesteps,
+    #                                   reset=reset)
+    #
+    #         output_name = join(output_dir, mixdb.mixtures[mixid].name)
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset(name='predict', data=predict)
+    # else:
+    #     features: list[Feature] = []
+    #     file_indices: list[slice] = []
+    #     total_frames = 0
+    #     for mixid in mixids:
+    #         current_feature, _ = mixdb.mixture_ft(mixid)
+    #         current_frames = current_feature.shape[0]
+    #         features.append(current_feature)
+    #         file_indices.append(slice(total_frames, total_frames + current_frames))
+    #         total_frames += current_frames
+    #     feature = np.vstack([features[i] for i in range(len(features))])
+    #
+    #     predict = pad_and_predict(feature=feature,
+    #                               model_name=model_name,
+    #                               model_metadata=model_metadata,
+    #                               frames_per_batch=frames_per_batch,
+    #                               batch_size=batch_size,
+    #                               timesteps=timesteps,
+    #                               reset=reset)
+    #
+    #     # Write data to separate files
+    #     for idx, mixid in enumerate(mixids):
+    #         output_name = join(output_dir, mixdb.mixtures[mixid].name)
+    #         with h5py.File(output_name, 'a') as f:
+    #             if 'predict' in f:
+    #                 del f['predict']
+    #             f.create_dataset('predict', data=predict[file_indices[idx]])
+    #
+    #     logger.info(f'Saved results to {output_dir}')
+    #
+
+# def pad_and_predict(feature: Feature,
+#                     model_name: str,
+#                     model_metadata: SonusAIMetaData,
+#                     frames_per_batch: int,
+#                     batch_size: int,
+#                     timesteps: int,
+#                     reset: bool) -> Predict:
+#     import onnxruntime as rt
+#     import numpy as np
+#
+#     from sonusai.utils import reshape_inputs
+#     from sonusai.utils import reshape_outputs
+#
+#     frames = feature.shape[0]
+#     padding = frames_per_batch - frames % frames_per_batch
+#     feature = np.pad(array=feature, pad_width=((0, padding), (0, 0), (0, 0)))
+#     feature, _ = reshape_inputs(feature=feature,
+#                                 batch_size=batch_size,
+#                                 timesteps=timesteps,
+#                                 flatten=model_metadata.flattened,
+#                                 add1ch=model_metadata.channel)
+#     sequences = feature.shape[0] // model_metadata.input_shape[0]
+#     feature = np.reshape(feature, [sequences, *model_metadata.input_shape])
+#
+#     model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+#     output_names = [n.name for n in model.get_outputs()]
+#     input_names = [n.name for n in model.get_inputs()]
+#
+#     predict = []
+#     for sequence in range(sequences):
+#         predict.append(model.run(output_names, {input_names[0]: feature[sequence]}))
+#         if reset:
+#             model = rt.InferenceSession(model_name, providers=['CPUExecutionProvider'])
+#
+#     predict_arr = np.vstack(predict)
+#     # Combine [sequences, batch_size, ...] into [frames, ...]
+#     predict_shape = predict_arr.shape
+#     predict_arr = np.reshape(predict_arr, [predict_shape[0] * predict_shape[1], *predict_shape[2:]])
+#     predict_arr, _ = reshape_outputs(predict=predict_arr, timesteps=timesteps)
+#     predict_arr = predict_arr[:frames, :]
+#
+#     return predict_arr
 
 
 if __name__ == '__main__':
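Stripped of logging and file handling, the inference path added to main() reduces to a few OnnxRuntime calls. A self-contained sketch under assumed names ('model.onnx' and the 1 x 100 x 40 feature are placeholders):

    import numpy as np
    import onnxruntime as ort

    # Open the session and inspect its single input, as load_ort_session does
    session = ort.InferenceSession('model.onnx', providers=['CPUExecutionProvider'])
    in0 = session.get_inputs()[0]
    out_names = [o.name for o in session.get_outputs()]

    # Placeholder feature: batch x timesteps x feature_params
    feature = np.random.rand(1, 100, 40).astype(np.float32)
    if 'float16' in in0.type:
        feature = np.float16(feature)  # match a float16 model input, as main() does

    predict = session.run(out_names, {in0.name: feature})[0]
    print(predict.shape)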