sonusai 0.16.0__py3-none-any.whl → 0.17.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +1 -0
- sonusai/audiofe.py +157 -61
- sonusai/calc_metric_spenh-save.py +1334 -0
- sonusai/calc_metric_spenh.py +15 -8
- sonusai/genft.py +15 -6
- sonusai/genmix.py +14 -6
- sonusai/genmixdb.py +14 -6
- sonusai/gentcst.py +13 -6
- sonusai/lsdb.py +15 -5
- sonusai/mkmanifest.py +14 -6
- sonusai/mkwav.py +15 -6
- sonusai/onnx_predict-old.py +240 -0
- sonusai/onnx_predict-save.py +487 -0
- sonusai/onnx_predict.py +446 -182
- sonusai/ovino_predict.py +508 -0
- sonusai/ovino_query_devices.py +47 -0
- sonusai/plot.py +16 -6
- sonusai/post_spenh_targetf.py +13 -6
- sonusai/summarize_metric_spenh.py +71 -0
- sonusai/torchl_onnx-old.py +216 -0
- sonusai/tplot.py +14 -6
- sonusai/utils/onnx_utils.py +128 -39
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/METADATA +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/RECORD +26 -19
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/WHEEL +1 -1
- {sonusai-0.16.0.dist-info → sonusai-0.17.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
"""sonusai summarize_metric_spenh
|
2
|
+
|
3
|
+
usage: summarize_metric_spenh [-hr] [-s SORT] LOC
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-s SORT, --sort SORT Sort by SORT column. [default: MIXID]
|
8
|
+
-r, --reverse Reverse sort order.
|
9
|
+
|
10
|
+
Summarize speech enhancement metrics results using data generated by SonusAI calc_metric_spenh.
|
11
|
+
|
12
|
+
Inputs:
|
13
|
+
LOC A SonusAI calc_metric_spenh results directory.
|
14
|
+
|
15
|
+
"""
|
16
|
+
import signal
|
17
|
+
|
18
|
+
|
19
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are required by
    the ``signal`` handler protocol and are ignored.
    """
    # Imported lazily so the package is only loaded if an interrupt arrives.
    from sonusai import logger as log

    log.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1), which raises SystemExit(1).
    raise SystemExit(1)
|
26
|
+
|
27
|
+
|
28
|
+
signal.signal(signal.SIGINT, signal_handler)
|
29
|
+
|
30
|
+
|
31
|
+
def summarize_metric_spenh(location: str, by: str = 'MIXID', reverse: bool = False) -> str:
    """Build a sorted text table from the ``*_metric_spenh.txt`` files in a directory.

    For every matching file, the second line supplies the column names (read
    once, from the first file that has one) and the third line supplies one
    data row. ``MIXID`` is prepended to the column names. All columns except
    the last two are converted to numbers; values that cannot be parsed
    become NaN.

    :param location: Directory containing ``*_metric_spenh.txt`` files
    :param by: Name of the column to sort by
    :param reverse: Sort in descending order when True
    :return: The table rendered as a string, without the index column
    :raises KeyError: If ``by`` is not one of the table's columns
    """
    import glob
    from os.path import join

    import pandas as pd

    # Escape the directory part so glob metacharacters in LOC ('[', ']', '*', '?')
    # are matched literally instead of silently selecting nothing.
    pattern = join(glob.escape(location), '*_metric_spenh.txt')
    files = sorted(glob.glob(pattern))

    need_header = True
    header = ['MIXID']
    data = []
    for file in files:
        with open(file, 'r') as f:
            for i, line in enumerate(f):
                if i == 1 and need_header:
                    need_header = False
                    header.extend(line.strip().split())
                elif i == 2:
                    data.append(line.strip().split())
                    break

    df = pd.DataFrame(data, columns=header)
    # The last two columns are left as text; everything before them is numeric.
    numeric_columns = header[0:-2]
    if numeric_columns:
        df[numeric_columns] = df[numeric_columns].apply(pd.to_numeric, errors='coerce')
    return df.sort_values(by=by, ascending=not reverse).to_string(index=False)
|
53
|
+
|
54
|
+
|
55
|
+
def main() -> None:
    """Command-line entry point: parse the docopt arguments and print the summary table."""
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    usage = trim_docstring(__doc__)
    args = docopt(usage, version=sonusai.__version__, options_first=True)

    table = summarize_metric_spenh(args['LOC'], args['--sort'], args['--reverse'])
    print(table)
|
68
|
+
|
69
|
+
|
70
|
+
if __name__ == '__main__':
|
71
|
+
main()
|
@@ -0,0 +1,216 @@
|
|
1
|
+
"""sonusai torchl_onnx
|
2
|
+
|
3
|
+
usage: torchl_onnx [-hv] [-b BATCH] [-t TSTEPS] [-o OUTPUT] MODEL CKPT
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose
|
8
|
+
-b BATCH, --batch BATCH Batch size [default: 1]
|
9
|
+
-t TSTEPS, --tsteps TSTEPS Timesteps [default: 1]
|
10
|
+
-o OUTPUT, --output OUTPUT Output directory.
|
11
|
+
|
12
|
+
Convert a trained Pytorch Lightning model to ONNX. The model is specified as an
|
13
|
+
sctl_*.py model file (sctl: sonusai custom torch lightning) and a checkpoint file
|
14
|
+
for loading weights.
|
15
|
+
|
16
|
+
Inputs:
|
17
|
+
MODEL SonusAI Python custom model file.
|
18
|
+
CKPT A Pytorch Lightning checkpoint file
|
19
|
+
BATCH Batch size used in onnx conversion, overrides value in model ckpt. Defaults to 1.
|
20
|
+
TSTEPS Timestep dimension size using in onnx conversion, overrides value in model ckpt if
|
21
|
+
the model has a timestep dimension. Else it is ignored.
|
22
|
+
|
23
|
+
Outputs:
|
24
|
+
OUTPUT/ A directory containing:
|
25
|
+
<CKPT>.onnx Model file with batch_size and timesteps equal to provided parameters
|
26
|
+
<CKPT>-b1.onnx Model file with batch_size=1 and if the timesteps dimension exists it
|
27
|
+
is set to 1 (useful for real-time inference applications)
|
28
|
+
torchl_onnx.log
|
29
|
+
|
30
|
+
Results are written into subdirectory <MODEL>-<TIMESTAMP> unless OUTPUT is specified.
|
31
|
+
|
32
|
+
"""
|
33
|
+
from sonusai import logger
|
34
|
+
|
35
|
+
|
36
|
+
def main() -> None:
    """Convert a PyTorch Lightning checkpoint to an ONNX file.

    Steps, in order:

    1. Parse the command line with docopt.
    2. Import the model definition file and load the checkpoint; either
       failure is logged and ends the program with exit status 1.
    3. Choose the output file ``<output_dir>/<ckpt_root>.onnx`` (a date suffix
       is added if that file already exists) and start a log file next to it.
    4. Build ``MyHyperModel`` from the checkpoint hyper-parameters when they
       are present, otherwise from the model defaults, applying the
       ``--batch`` and ``--tsteps`` overrides.
    5. Load the weights, put the model in eval mode, and export it to ONNX.

    The batch size is always forced to 1, whatever ``--batch`` says.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args['--verbose']
    batch_size = args['--batch']
    timesteps = args['--tsteps']
    model_path = args['MODEL']
    ckpt_path = args['CKPT']
    output_dir = args['--output']

    from os import makedirs
    from os.path import basename, splitext
    from sonusai.utils import import_keras_model

    # Import model definition file first to check
    model_base = basename(model_path)
    logger.info(f'Importing model from {model_base}')
    try:
        litemodule = import_keras_model(model_path)  # note works for pytorch lightning as well as keras
    except Exception as e:
        logger.exception(f'Error: could not import model from {model_path}: {e}')
        raise SystemExit(1) from e

    # Load checkpoint first to get hparams if available
    from torch import load
    ckpt_base = basename(ckpt_path)
    ckpt_root = splitext(ckpt_base)[0]
    logger.info(f'Loading checkpoint from {ckpt_base}')
    try:
        checkpoint = load(ckpt_path, map_location=lambda storage, loc: storage)
    except Exception as e:
        logger.exception(f'Error: could not load checkpoint from {ckpt_path}: {e}')
        raise SystemExit(1) from e

    from os.path import join, isdir, dirname, exists
    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from torch import randn

    if batch_size is not None:
        batch_size = int(batch_size)
        if batch_size != 1:
            batch_size = 1
            logger.info('For now prediction only supports batch_size = 1, forcing it to 1 now')

    if timesteps is not None:
        timesteps = int(timesteps)

    if output_dir is None:
        output_dir = dirname(ckpt_path)
    elif not isdir(output_dir):
        makedirs(output_dir, exist_ok=True)

    ofname = join(output_dir, ckpt_root + '.onnx')
    # If the plain name is taken, add the current date (year-month-day) to the name.
    if exists(ofname):
        from datetime import datetime
        ts = datetime.now()
        ofname = join(output_dir, ckpt_root + '-' + ts.strftime('%Y%m%d') + '.onnx')
    ofname_root = splitext(ofname)[0]

    # Setup logging file
    create_file_handler(ofname_root + '-onnx.log')
    update_console_handler(verbose)
    initial_log_messages('torchl_onnx')
    logger.info(f'Imported model from {model_base}')
    logger.info(f'Loaded checkpoint from {ckpt_base}')

    if 'hyper_parameters' in checkpoint:
        hparams = checkpoint['hyper_parameters']
        # .get(): a missing 'hparams_name' entry should not abort the conversion.
        logger.info(f'Found hyper-params on checkpoint named {checkpoint.get("hparams_name")} '
                    f'with {len(hparams)} total hparams.')
        # batch_size is either None or already forced to 1 at this point.
        if batch_size is not None and hparams['batch_size'] != batch_size:
            logger.info(f'Overriding batch_size: default = {hparams["batch_size"]}; specified = {batch_size}.')
            hparams["batch_size"] = batch_size

        if timesteps is not None:
            if hparams['timesteps'] == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if hparams['timesteps'] != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override of 0, using model default.')
                timesteps = hparams['timesteps']

            if hparams['timesteps'] != timesteps:
                logger.info(f'Overriding timesteps: default = {hparams["timesteps"]}; specified = {timesteps}.')
                hparams['timesteps'] = timesteps

        logger.info(f'Building model with hparams and batch_size={batch_size}, timesteps={timesteps}')
        try:
            model = litemodule.MyHyperModel(**hparams)  # use hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1) from e
    else:
        logger.info('Warning: found checkpoint with no hyper-parameters, building model with defaults')
        try:
            tmp = litemodule.MyHyperModel()  # use default hparams
        except Exception as e:
            logger.exception(f'Error: model build (MyHyperModel) in {model_base} failed: {e}')
            raise SystemExit(1) from e

        if batch_size is not None:
            if tmp.batch_size != batch_size:
                logger.info(f'Overriding batch_size: default = {tmp.batch_size}; specified = {batch_size}.')
        else:
            batch_size = tmp.batch_size  # inherit

        if timesteps is not None:
            if tmp.timesteps == 0 and timesteps != 0:
                logger.warning('Model does not contain timesteps; ignoring override.')
                timesteps = 0

            if tmp.timesteps != 0 and timesteps == 0:
                logger.warning('Model contains timesteps; ignoring override.')
                timesteps = tmp.timesteps

            if tmp.timesteps != timesteps:
                logger.info(f'Overriding timesteps: default = {tmp.timesteps}; specified = {timesteps}.')
        else:
            timesteps = tmp.timesteps

        logger.info(f'Building model with default hparams and batch_size= {batch_size}, timesteps={timesteps}')
        model = litemodule.MyHyperModel(timesteps=timesteps, batch_size=batch_size)

    logger.info('')
    logger.info('')
    logger.info(f'feature {model.hparams.feature}')
    logger.info(f'num_classes {model.num_classes}')
    logger.info(f'batch_size {model.hparams.batch_size}')
    logger.info(f'timesteps {model.hparams.timesteps}')
    logger.info(f'flatten {model.flatten}')
    logger.info(f'add1ch {model.add1ch}')
    logger.info(f'truth_mutex {model.truth_mutex}')
    logger.info(f'input_shape {model.input_shape}')
    logger.info('')
    logger.info(f'Loading weights from {ckpt_base}')
    # load_from_checkpoint() was found to have problems here, so the state dict is loaded directly.
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    # Build a new shape list; inserting into model.input_shape would alter the model's own attribute.
    insample_shape = [batch_size, *model.input_shape]
    input_sample = randn(insample_shape)
    logger.info('Creating onnx model ...')
    for m in model.modules():
        if 'instancenorm' in m.__class__.__name__.lower():
            logger.info(f'Forcing train=false for instancenorm instance {m}, {m.__class__.__name__.lower()}')
            m.train(False)
    model.to_onnx(file_path=ofname, input_sample=input_sample, export_params=True)
|
209
|
+
|
210
|
+
|
211
|
+
if __name__ == '__main__':
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        # exit() is a helper installed by the site module and may be absent
        # (e.g. under ``python -S``); raising SystemExit always works and keeps
        # the same exit status of 0.
        raise SystemExit(0)
|
sonusai/tplot.py
CHANGED
@@ -41,7 +41,19 @@ options:
|
|
41
41
|
A multi-page plot TARGET-tplot.pdf or CONFIG-tplot.pdf is generated.
|
42
42
|
|
43
43
|
"""
|
44
|
-
|
44
|
+
import signal
|
45
|
+
|
46
|
+
|
47
|
+
def signal_handler(_sig, _frame):
    """Handle SIGINT: log the cancellation and terminate with exit status 1.

    Both arguments (signal number and current stack frame) are required by
    the ``signal`` handler protocol and are ignored.
    """
    # Imported lazily so the package is only loaded if an interrupt arrives.
    from sonusai import logger as log

    log.info('Canceled due to keyboard interrupt')
    # Equivalent to sys.exit(1), which raises SystemExit(1).
    raise SystemExit(1)
|
54
|
+
|
55
|
+
|
56
|
+
signal.signal(signal.SIGINT, signal_handler)
|
45
57
|
|
46
58
|
|
47
59
|
# TODO: re-work for modern mixdb API
|
@@ -328,8 +340,4 @@ def main() -> None:
|
|
328
340
|
|
329
341
|
|
330
342
|
if __name__ == '__main__':
|
331
|
-
|
332
|
-
main()
|
333
|
-
except KeyboardInterrupt:
|
334
|
-
logger.info('Canceled due to keyboard interrupt')
|
335
|
-
raise SystemExit(0)
|
343
|
+
main()
|
sonusai/utils/onnx_utils.py
CHANGED
@@ -1,6 +1,12 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
2
|
|
3
|
+
from sonusai import logger
|
4
|
+
from typing import Any #List, Optional, Tuple
|
5
|
+
import onnxruntime as ort
|
3
6
|
from onnxruntime import InferenceSession
|
7
|
+
import onnx
|
8
|
+
from onnx import ValueInfoProto
|
9
|
+
from os.path import basename, splitext, exists, isfile
|
4
10
|
|
5
11
|
|
6
12
|
@dataclass(frozen=True)
|
@@ -14,52 +20,135 @@ class SonusAIMetaData:
|
|
14
20
|
feature: str
|
15
21
|
|
16
22
|
|
17
|
-
def
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
23
|
+
def get_and_check_inputs(model: onnx.ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
    """Return the graph inputs of an ONNX model together with their shapes.

    Graph inputs that are also initializers are left out. A warning is logged
    when the number of remaining inputs is not exactly one.

    :param model: ONNX model
    :return: Tuple of (inputs, shapes). Each shape is a list with one entry
        per dimension, using 0 for any dimension without a fixed value, or the
        string "unknown rank" when the input carries no shape.
    """
    # Ignore initializer inputs (only seen in older onnx < v1.5).
    initializer_names = [init.name for init in model.graph.initializer]
    onnx_inputs = [candidate for candidate in model.graph.input if candidate.name not in initializer_names]
    if len(onnx_inputs) != 1:
        logger.warning(f'Warning: onnx model does not have 1 input, but {len(onnx_inputs)}')

    inshapes = []
    for value_info in onnx_inputs:
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            inshapes.append("unknown rank")
            continue
        # Symbolic (dim_param) and unnamed dimensions are both reported as 0.
        inshapes.append([dim.dim_value if dim.HasField("dim_value") else 0
                         for dim in tensor_type.shape.dim])

    return onnx_inputs, inshapes
|
24
51
|
|
25
|
-
:param model: ONNX model
|
26
|
-
:param is_flattened: Model feature data is flattened
|
27
|
-
:param has_timestep: Model has timestep dimension
|
28
|
-
:param has_channel: Model has channel dimension
|
29
|
-
:param is_mutex: Model label output is mutually exclusive
|
30
|
-
:param feature: Model feature type
|
31
|
-
"""
|
32
|
-
is_flattened_flag = model.metadata_props.add()
|
33
|
-
is_flattened_flag.key = 'is_flattened'
|
34
|
-
is_flattened_flag.value = str(is_flattened)
|
35
52
|
|
36
|
-
|
37
|
-
|
38
|
-
|
53
|
+
def get_and_check_outputs(model: onnx.ModelProto) -> tuple[list[ValueInfoProto], list[list[int | Any] | str]]:
    """Return the graph outputs of an ONNX model together with their shapes.

    A warning is logged when the model does not have exactly one output.

    :param model: ONNX model
    :return: Tuple of (outputs, shapes). Each shape is a list with one entry
        per dimension, using 0 for any dimension without a fixed value, or the
        string "unknown rank" when the output carries no shape.
    """
    onnx_outputs = list(model.graph.output)
    if len(onnx_outputs) != 1:
        logger.warning(f'Warning: onnx model does not have 1 output, but {len(onnx_outputs)}')

    oshapes = []
    for value_info in onnx_outputs:
        tensor_type = value_info.type.tensor_type
        if not tensor_type.HasField("shape"):
            oshapes.append("unknown rank")
            continue
        # Symbolic (dim_param) and unnamed dimensions are both reported as 0.
        oshapes.append([dim.dim_value if dim.HasField("dim_value") else 0
                        for dim in tensor_type.shape.dim])

    return onnx_outputs, oshapes
|
47
75
|
|
48
|
-
|
49
|
-
|
50
|
-
|
76
|
+
|
77
|
+
def add_sonusai_metadata(model, hparams):
    """Add SonusAI hyper-parameter metadata to an ONNX model under the key "hparams".

    The hyper-parameters are stored as ``str(hparams)``. Before storing, the
    string is parsed back and compared with ``hparams`` so that only values
    which survive the round trip are written.

    :param model: ONNX model
    :param hparams: Dictionary of hyper-parameters (i.e. extracted from a checkpoint)
    :return: The same model object, with the metadata entry added
    :raises ValueError: If ``hparams`` cannot be written as a Python literal
        and read back unchanged

    Note SonusAI conventions require models to have:
    - feature:       Model feature type
    - is_flattened:  Model input feature data is flattened (stride + bins combined)
    - timesteps:     Size of timestep dimension (0 for no dimension)
    - add1ch:        Model input has channel dimension
    - truth_mutex:   Model label output is mutually exclusive
    """
    import ast

    serialized = str(hparams)
    # literal_eval only accepts Python literals, so an unexpected repr is
    # rejected instead of being executed as code. A raised error (rather than
    # an assert) keeps the check active under ``python -O``.
    try:
        round_trip = ast.literal_eval(serialized)
    except (ValueError, SyntaxError, TypeError) as e:
        raise ValueError(f'hparams cannot be stored as a Python literal: {e}') from e
    if round_trip != hparams:
        raise ValueError('hparams does not survive conversion to a string and back')

    # Add hyper-parameters as metadata in onnx model under hparams key
    meta = model.metadata_props.add()
    meta.key = "hparams"
    meta.value = serialized

    return model
|
53
97
|
|
54
98
|
|
55
|
-
def get_sonusai_metadata(
|
56
|
-
"""Get SonusAI metadata from an ONNX
|
99
|
+
def get_sonusai_metadata(session: 'InferenceSession') -> dict[str, Any]:
    """Get SonusAI hyper-parameter metadata from an ONNX Runtime session.

    Reads the "hparams" entry of the model's custom metadata map and parses
    it as a Python literal.

    :param session: ONNX Runtime inference session
    :return: Dictionary of hyper-parameters
    :raises KeyError: If the model has no "hparams" metadata entry
    :raises ValueError: If the entry is not a valid Python literal
    :raises SyntaxError: If the entry is not valid Python syntax
    """
    import ast

    meta = session.get_modelmeta()
    # The metadata comes from a model file, so it is parsed as a literal and
    # never executed as code.
    hparams = ast.literal_eval(meta.custom_metadata_map["hparams"])

    return hparams
|
116
|
+
|
117
|
+
def load_ort_session(model_path, providers=('CPUExecutionProvider',)):
    """Open an ONNX Runtime session for a model file and read its SonusAI metadata.

    :param model_path: Path to the ONNX model file
    :param providers: Execution providers passed to ``ort.InferenceSession``
    :return: Tuple of (session, options, model_root, hparams, inputs, outputs):
        - session: the opened ``ort.InferenceSession``
        - options: a newly created ``ort.SessionOptions`` (it is not passed to the session)
        - model_root: the model file name without directory or extension
        - hparams: the hyper-parameter dictionary, or None if it is missing,
          unparsable, or lacks a required key
        - inputs, outputs: ``session.get_inputs()`` and ``session.get_outputs()``
    :raises SystemExit: With status 1 if the file does not exist or cannot be loaded
    """
    import ast

    if not (exists(model_path) and isfile(model_path)):
        # No exception is being handled here, so logger.exception() would append a meaningless traceback.
        logger.error(f'Error: model file does not exist: {model_path}')
        raise SystemExit(1)

    model_basename = basename(model_path)
    model_root = splitext(model_basename)[0]
    logger.info(f'Importing model from {model_basename}')
    # The default is an immutable tuple; hand onnxruntime a new list each call.
    provider_list = list(providers) if isinstance(providers, tuple) else providers
    try:
        session = ort.InferenceSession(model_path, providers=provider_list)
        options = ort.SessionOptions()
    except Exception as e:
        logger.exception(f'Error: could not load onnx model from {model_path}: {e}')
        raise SystemExit(1) from e

    logger.info(f'Opened session with provider options: {session.get_provider_options()}.')
    try:
        meta = session.get_modelmeta()
        # Metadata comes from the model file: parse it as a literal, never execute it.
        hparams = ast.literal_eval(meta.custom_metadata_map["hparams"])
        logger.info(f'Sonusai hyper-parameter metadata was found in model with {len(hparams)} parameters, '
                    f'checking for required ones ...')
        # Print to log here will fail if required parameters not available.
        logger.info(f'feature {hparams["feature"]}')
        logger.info(f'batch_size {hparams["batch_size"]}')
        logger.info(f'timesteps {hparams["timesteps"]}')
        logger.info(f'flatten, add1ch {hparams["flatten"]}, {hparams["add1ch"]}')
        logger.info(f'truth_mutex {hparams["truth_mutex"]}')
    except Exception:
        # Best effort: a model without usable metadata is still returned.
        # ``except Exception`` lets SystemExit and KeyboardInterrupt through.
        hparams = None
        logger.warning('Warning: onnx model does not have required SonusAI hyper-parameters.')

    inputs = session.get_inputs()
    outputs = session.get_outputs()

    return session, options, model_root, hparams, inputs, outputs
|
@@ -1,7 +1,8 @@
|
|
1
|
-
sonusai/__init__.py,sha256=
|
1
|
+
sonusai/__init__.py,sha256=vzTFfRB-NeO-Sm3puySDJOybk3ND_Oj6w0EejQPmH1U,2978
|
2
2
|
sonusai/aawscd_probwrite.py,sha256=GukR5owp_0A3DrqSl9fHWULYgclNft4D5OkHIwfxxkc,3698
|
3
|
-
sonusai/audiofe.py,sha256=
|
4
|
-
sonusai/calc_metric_spenh.py,sha256
|
3
|
+
sonusai/audiofe.py,sha256=yPtbxeRAzlnPcRESXKNVexvIm6fM4WVxEHjd0w9n5O0,12455
|
4
|
+
sonusai/calc_metric_spenh-save.py,sha256=-LR5BtAnYNYKav1B2ZsB7gGevidCsQ91yFEaH8Ycyr8,61765
|
5
|
+
sonusai/calc_metric_spenh.py,sha256=_92RWCyxAf7_S61L5oX6o4GuvuOCDKn7BPg7pB0r1kY,61836
|
5
6
|
sonusai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
7
|
sonusai/data/genmixdb.yml,sha256=-XSs_hUR6wHJVoTPmSewzXL7u61X-xmHY46lNPatxSE,1025
|
7
8
|
sonusai/data/speech_ma01_01.wav,sha256=PK0vMKg-NR6rPE3KouxHGF6PKXnJCr7AwjMqfu98LUA,76644
|
@@ -9,11 +10,11 @@ sonusai/data/whitenoise.wav,sha256=I2umov0m34y56F9IsIBi1XtE76ZeZaSKDf70cJRe3pI,1
|
|
9
10
|
sonusai/doc/__init__.py,sha256=rP5Hgn0Iys_xkuv4caxngdqehuU4zLZsiKuv8Nde67M,19
|
10
11
|
sonusai/doc/doc.py,sha256=3z210v6ZckuOlsGZ3ySQBdlCNmBp2M1ahqhqG_eUN58,22664
|
11
12
|
sonusai/doc.py,sha256=l8CaFgLI8mqx4tn0aXfxKqa2dy9GgC0zjYxZAkpmi1E,878
|
12
|
-
sonusai/genft.py,sha256=
|
13
|
-
sonusai/genmix.py,sha256=
|
14
|
-
sonusai/genmixdb.py,sha256=
|
15
|
-
sonusai/gentcst.py,sha256=
|
16
|
-
sonusai/lsdb.py,sha256=
|
13
|
+
sonusai/genft.py,sha256=OzET3iTE-QhrUckzidfZvCDXZlAxIF5Xe5NEf856Vvk,5662
|
14
|
+
sonusai/genmix.py,sha256=TU5aTebGHsbfwsRbynYbegGBelSma9khuQkDk0dFE3I,7075
|
15
|
+
sonusai/genmixdb.py,sha256=M67Y_SEysgHfTmHHOdOjxdpuryTMDNgbDteCzR1uLk8,19669
|
16
|
+
sonusai/gentcst.py,sha256=W1ZO3xs7CoZkFcvOTH-FLJOIA4I7Wzb0HVRC3hGGSaM,20223
|
17
|
+
sonusai/lsdb.py,sha256=fMRqPlAu4B-4MsTXX-NaWXYyJ_dAOJlS-LrvQPQQsXg,6028
|
17
18
|
sonusai/main.py,sha256=GC-pQrSqx9tWwIcmEo6V9SraEv5KskBLS_W_wz-f2ZM,2509
|
18
19
|
sonusai/metrics/__init__.py,sha256=56itZW3S1I7ZYvbxPmFIVPAh1AIJZdljByz1uCrHqFE,635
|
19
20
|
sonusai/metrics/calc_class_weights.py,sha256=dyY7daEIf5Ms5tfTf6wF0fkx_GnMADHOZR_rtsfGoVM,3933
|
@@ -60,14 +61,20 @@ sonusai/mixture/truth_functions/file.py,sha256=jOJuC_3y9BH6GGOp9eKcbVrHLVRzUA80B
|
|
60
61
|
sonusai/mixture/truth_functions/phoneme.py,sha256=stYdlPuNytQK_LLT61OJLfYSqKd-sDjQZdtJKGzt5wA,479
|
61
62
|
sonusai/mixture/truth_functions/sed.py,sha256=8cHjEFjZaH_0hIOHhPmj4AJz2GpEADM6Ys2x4NoiWSY,2469
|
62
63
|
sonusai/mixture/truth_functions/target.py,sha256=KAsjugDRooOA5BRcHVAbZRgV7l8S5CFg7CZ0XtKZaQ0,5764
|
63
|
-
sonusai/mkmanifest.py,sha256=
|
64
|
-
sonusai/mkwav.py,sha256=
|
65
|
-
sonusai/onnx_predict.py,sha256=Bz_pR28oAZBarNajlKwyzBxmW7ktum77SmxDN2onKPM,9060
|
66
|
-
sonusai/
|
67
|
-
sonusai/
|
64
|
+
sonusai/mkmanifest.py,sha256=7lfK7YOdgAEP_Lxrf-YDxZ5iLH9MJuaOltBVpav2M9M,8705
|
65
|
+
sonusai/mkwav.py,sha256=kLfC2ZuF-t8P97nqYw2falTZpymxAeXv0YTJCe6nK10,5356
|
66
|
+
sonusai/onnx_predict-old.py,sha256=Bz_pR28oAZBarNajlKwyzBxmW7ktum77SmxDN2onKPM,9060
|
67
|
+
sonusai/onnx_predict-save.py,sha256=ewiV5-HcW5zcDWuIF9xEbdBwbdL8vNWfu2_kaur5jAo,22354
|
68
|
+
sonusai/onnx_predict.py,sha256=-ETvGH7fHXjnmY-c2r_8gOHEX1VSLwdSToZ6gSBP3_w,23021
|
69
|
+
sonusai/ovino_predict.py,sha256=QtWY_YEdaeqJL5yikfeeDwhbBtjXYjGNtR9310PGGSc,21830
|
70
|
+
sonusai/ovino_query_devices.py,sha256=XkXdOlZldI0MfG6nXZfTq8OgECaY8gAzQl36sJRuaIU,1584
|
71
|
+
sonusai/plot.py,sha256=ERkmxMM3qjcCDm4LGDQY4fRAncCYAzP7uW8iZ7_brcg,17105
|
72
|
+
sonusai/post_spenh_targetf.py,sha256=xOz5T6WZuyTHmfbtILIY9skgH064Wvi2GF2Bo5L3YMU,4998
|
68
73
|
sonusai/queries/__init__.py,sha256=oKY5JeqZ4Cz7DwCwPc1_ydB8bUs6KaMcWFp_w02TjOs,255
|
69
74
|
sonusai/queries/queries.py,sha256=FNMUKnoY_Ya9S5sNhsB8ppwy0B7V55ilbbjhQRv_UN8,7552
|
70
|
-
sonusai/
|
75
|
+
sonusai/summarize_metric_spenh.py,sha256=OiZe_bhCq5esXNhsOkHDD7g4ssYrpENDHvDVoPzV9iw,1822
|
76
|
+
sonusai/torchl_onnx-old.py,sha256=5JYow3XpBaUdtuyAW0mOZyCKL_4FrHvEekYBRdDT6KA,8967
|
77
|
+
sonusai/tplot.py,sha256=85T6OPZfxVegHBiSuilFpdgCNMEE0VKAuciNy4rCY5Y,14544
|
71
78
|
sonusai/utils/__init__.py,sha256=TCXlcW8W0Up2f5ciSgz3DabvH1MxrrWD0LK6pQTJkeA,2215
|
72
79
|
sonusai/utils/asl_p56.py,sha256=-bvQpd-jRQVURbkZJpRoyEAq6gTv9Rc3oFDbh5_lcjY,3861
|
73
80
|
sonusai/utils/asr.py,sha256=6y6VYJizHpuQ3MgKbEQ4t2gofO-MW6Ez23oAd6d23IE,2920
|
@@ -96,7 +103,7 @@ sonusai/utils/human_readable_size.py,sha256=SjYT0fUlpbfCzCXHo6csir-VMwqfs5ogr-fg
|
|
96
103
|
sonusai/utils/max_text_width.py,sha256=pxiJMwb_zlkNntexgo7S6lAuF7NLLZvFdOCkxdsQJVY,315
|
97
104
|
sonusai/utils/model_utils.py,sha256=lt2KOGJqsinG71W0i3U29UXFO-47GMAlEabsf2um7bA,862
|
98
105
|
sonusai/utils/numeric_conversion.py,sha256=GRO_2Fba8CcxcFY7bEXKOEUEUX6neA-VN__Bxi1ULsE,340
|
99
|
-
sonusai/utils/onnx_utils.py,sha256=
|
106
|
+
sonusai/utils/onnx_utils.py,sha256=L0BcwF0or1UwxYWzvWNWNJHKvG_oEmI2AxYh4msp2vc,6862
|
100
107
|
sonusai/utils/parallel.py,sha256=bxedjCzBv9oxzU7NajRr6mOKmkCWr2P7FWAI0p2p9N8,1981
|
101
108
|
sonusai/utils/print_mixture_details.py,sha256=BzYM4-wHHNa6zxPzBMUJxwKt0gKHmvbwdd7Yp0w15Yk,3017
|
102
109
|
sonusai/utils/ranges.py,sha256=NPBZOVzMb95GTOIxltVO-wSzgcXqZ14wbdV46JDLKrw,1222
|
@@ -109,7 +116,7 @@ sonusai/utils/stratified_shuffle_split.py,sha256=rJNXvBp-GxoKzH3OpL7k0ANSu5xMP2z
|
|
109
116
|
sonusai/utils/wave.py,sha256=O4ZXkZ6wjrKGa99wBCdFd8G6bp91MXXDnmGihpaEMh0,856
|
110
117
|
sonusai/utils/yes_or_no.py,sha256=eMLXBVH0cEahiXY4W2KNORmwNQ-ba10eRtldh0y4NYg,263
|
111
118
|
sonusai/vars.py,sha256=m2AefF0m5bXWGXpJj8Pi42zWL2ydeEj7bkak3GrtMyM,940
|
112
|
-
sonusai-0.
|
113
|
-
sonusai-0.
|
114
|
-
sonusai-0.
|
115
|
-
sonusai-0.
|
119
|
+
sonusai-0.17.0.dist-info/METADATA,sha256=wn6MaT5JXlVzp45_huLIAnJeH04kQZm8r0_vWiDX3LU,2443
|
120
|
+
sonusai-0.17.0.dist-info/WHEEL,sha256=FMvqSimYX_P7y0a7UY-_Mc83r5zkBZsCYPm7Lr0Bsq4,88
|
121
|
+
sonusai-0.17.0.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
122
|
+
sonusai-0.17.0.dist-info/RECORD,,
|
File without changes
|