onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_x86_64.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- onnxruntime_extensions/__init__.py +82 -0
- onnxruntime_extensions/_cuops.py +564 -0
- onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
- onnxruntime_extensions/_extensions_pydll.pyi +45 -0
- onnxruntime_extensions/_hf_cvt.py +331 -0
- onnxruntime_extensions/_ocos.py +133 -0
- onnxruntime_extensions/_ortapi2.py +274 -0
- onnxruntime_extensions/_torch_cvt.py +231 -0
- onnxruntime_extensions/_version.py +2 -0
- onnxruntime_extensions/cmd.py +66 -0
- onnxruntime_extensions/cvt.py +306 -0
- onnxruntime_extensions/onnxprocess/__init__.py +12 -0
- onnxruntime_extensions/onnxprocess/_builder.py +53 -0
- onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
- onnxruntime_extensions/onnxprocess/_session.py +355 -0
- onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
- onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
- onnxruntime_extensions/pnp/__init__.py +13 -0
- onnxruntime_extensions/pnp/_base.py +124 -0
- onnxruntime_extensions/pnp/_imagenet.py +65 -0
- onnxruntime_extensions/pnp/_nlp.py +148 -0
- onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
- onnxruntime_extensions/pnp/_torchext.py +310 -0
- onnxruntime_extensions/pnp/_unifier.py +45 -0
- onnxruntime_extensions/pnp/_utils.py +302 -0
- onnxruntime_extensions/pp_api.py +83 -0
- onnxruntime_extensions/tools/__init__.py +0 -0
- onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
- onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
- onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
- onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
- onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
- onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
- onnxruntime_extensions/util.py +186 -0
- onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
- onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
- onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
- onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
- onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0

onnxruntime_extensions/_ortapi2.py
@@ -0,0 +1,274 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

"""
_ortapi2.py: ONNXRuntime-Extensions Python API
"""

import copy
import numpy as np
from ._ocos import default_opset_domain, get_library_path, Opdef
from ._cuops import onnx, onnx_proto, SingleOpGraph

_ort_check_passed = False
try:
    from packaging import version as _ver
    import onnxruntime as _ort

    if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
        _ort_check_passed = True
except ImportError:
    pass

if not _ort_check_passed:
    raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")


def _ensure_opset_domain(model):
    op_domain_name = default_opset_domain()
    domain_missing = True
    for oi_ in model.opset_import:
        if oi_.domain == op_domain_name:
            domain_missing = False

    if domain_missing:
        model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)])

    return model


def hook_model_op(model, node_name, hook_func, input_types):
    """
    Add a hook function node to the ONNX model, which can be used for model diagnosis.
    :param model: The ONNX model loaded as ModelProto
    :param node_name: The node name where the hook will be installed
    :param hook_func: The hook function, called back during model inference
    :param input_types: The input types as a list
    :return: The ONNX model with the hook installed
    """

    # onnx.shape_inference is very unstable, useless.
    # hkd_model = shape_inference.infer_shapes(model)
    hkd_model = model

    n_idx = 0
    hnode, nnode = (None, None)
    nodes = list(hkd_model.graph.node)
    brkpt_name = node_name + "_hkd"
    optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
    for n_ in nodes:
        if n_.name == node_name:
            input_names = list(n_.input)
            brk_output_name = [i_ + "_hkd" for i_ in input_names]
            hnode = onnx.helper.make_node(
                optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()
            )
            nnode = n_
            del nnode.input[:]
            nnode.input.extend(brk_output_name)
            break
        n_idx += 1

    if hnode is None:
        raise ValueError("{} is not an operator node name".format(node_name))

    repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :]
    del hkd_model.graph.node[:]
    hkd_model.graph.node.extend(repacked)

    Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types)
    return _ensure_opset_domain(hkd_model)
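
# Usage sketch (illustrative, not part of the packaged source); the node name
# "Gemm_0" and the hook body are placeholders, and the dt_float constants are
# assumed to come from the PyCustomOpDef class exposed by the extensions binary:
#
#   from onnxruntime_extensions import PyCustomOpDef
#
#   def inspect_inputs(*tensors):
#       for t in tensors:
#           print(t.shape)
#       return tensors   # a hook must pass its inputs through unchanged
#
#   hooked = hook_model_op(model, "Gemm_0", inspect_inputs,
#                          [PyCustomOpDef.dt_float, PyCustomOpDef.dt_float])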


def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
    """
    Replace an existing input of a model with the new inputs, plus some extra nodes
    :param model: The ONNX model loaded as ModelProto
    :param target_input: The input name to be replaced
    :param extra_nodes: The extra nodes to be added
    :param new_inputs: The new input (type: ValueInfoProto) sequence
    :return: The ONNX model after modification
    """
    graph = model.graph
    new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs
    new_nodes = list(model.graph.node) + extra_nodes
    new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))

    new_model = copy.deepcopy(model)
    new_model.graph.CopyFrom(new_graph)

    return _ensure_opset_domain(new_model)
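
# Usage sketch (illustrative, not part of the packaged source); assumes the
# model has a float input "pcm" that should instead be produced from a new
# uint8 input "audio_stream" by a decoder node (names are placeholders, and
# any op-specific attributes are elided):
#
#   new_in = onnx.helper.make_tensor_value_info(
#       "audio_stream", onnx.TensorProto.UINT8, [None])
#   decoder = onnx.helper.make_node(
#       "AudioDecoder", ["audio_stream"], ["pcm"], domain=default_opset_domain())
#   model = expand_onnx_inputs(model, "pcm", [decoder], [new_in])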


def get_opset_version_from_ort():
    _ORT_OPSET_SUPPORT_TABLE = {
        "1.5": 11,
        "1.6": 12,
        "1.7": 13,
        "1.8": 14,
        "1.9": 15,
        "1.10": 15,
        "1.11": 16,
        "1.12": 17,
        "1.13": 17,
        "1.14": 18,
        "1.15": 18,
    }

    ort_ver_string = ".".join(_ort.__version__.split(".")[0:2])
    max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
    if ort_ver_string > max_ver:
        ort_ver_string = max_ver
    return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
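
# Note (explanatory, not part of the packaged source): the clamp above
# compares version strings lexicographically, which holds for the ORT
# releases this module accepts (>= 1.10). An ORT build newer than the table
# falls back to the newest known opset, e.g.:
#
#   >>> get_opset_version_from_ort()   # with onnxruntime 1.16 installed
#   18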


def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
    if opset_version == 0:
        opset_version = get_opset_version_from_ort()
    fn_mm = (
        onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model
    )
    model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)])
    model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
    return model
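
# Usage sketch (illustrative, not part of the packaged source): wrap a graph
# built by SingleOpGraph into a ModelProto that imports both the standard
# "ai.onnx" opset and the extensions domain.
#
#   graph = SingleOpGraph.build_graph("StringLower")
#   model = make_onnx_model(graph)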


class OrtPyFunction:
    """
    OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
    equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
    standard Python function. The order of the function arguments correlates directly with
    the sequence of the inputs/outputs in the ONNX graph.
    """

    def get_ort_session_options(self):
        so = _ort.SessionOptions()
        for k, v in self.extra_session_options.items():
            so.__setattr__(k, v)
        so.register_custom_ops_library(get_library_path())
        return so

    def __init__(self, path_or_model=None, cpu_only=None):
        self._onnx_model = None
        self.ort_session = None
        self.default_inputs = {}
        self.execution_providers = ["CPUExecutionProvider"]
        if not cpu_only:
            if _ort.get_device() == "GPU":
                self.execution_providers = ["CUDAExecutionProvider"]
        self.extra_session_options = {}
        mpath = None
        if isinstance(path_or_model, str):
            oxml = onnx.load_model(path_or_model)
            mpath = path_or_model
        else:
            oxml = path_or_model
        if path_or_model is not None:
            self._bind(oxml, mpath)

    def create_from_customop(self, op_type, *args, **kwargs):
        graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
        self._bind(make_onnx_model(graph))
        return self

    def add_default_input(self, **kwargs):
        inputs = {
            ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8)
            for ky_, val_ in kwargs.items()
        }

        self.default_inputs.update(inputs)

    @property
    def onnx_model(self):
        assert self._oxml is not None, "No onnx model attached yet."
        return self._oxml

    @property
    def input_names(self):
        return [vi_.name for vi_ in self.onnx_model.graph.input]

    @property
    def output_names(self):
        return [vi_.name for vi_ in self.onnx_model.graph.output]

    def _bind(self, oxml, model_path=None):
        self.inputs = list(oxml.graph.input)
        self.outputs = list(oxml.graph.output)
        self._oxml = oxml
        if model_path is not None:
            self.ort_session = _ort.InferenceSession(
                model_path, self.get_ort_session_options(), self.execution_providers
            )
        return self

    def _ensure_ort_session(self):
        if self.ort_session is None:
            sess = _ort.InferenceSession(
                self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers
            )
            self.ort_session = sess

        return self.ort_session

    @staticmethod
    def _get_kwarg_device(kwargs):
        cpuonly = kwargs.get("cpu_only", None)
        if cpuonly is not None:
            del kwargs["cpu_only"]
        return cpuonly

    @classmethod
    def from_customop(cls, op_type, *args, **kwargs):
        return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)

    @classmethod
    def from_model(cls, path_or_model, *args, **kwargs):
        fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
        return fn

    def _argument_map(self, *args, **kwargs):
        idx = 0
        feed = {}
        for i_ in self.inputs:
            if i_.name in self.default_inputs:
                feed[i_.name] = self.default_inputs[i_.name]
                continue

            x = args[idx]
            ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
            # numpy defaults to int32 on some platforms and to int64 on others.
            feed[i_.name] = (
                ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
            )
            idx += 1

        feed.update(kwargs)
        return feed

    def __call__(self, *args, **kwargs):
        self._ensure_ort_session()
        outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
        return outputs[0] if len(outputs) == 1 else tuple(outputs)
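
# Usage sketch (illustrative, not part of the packaged source); "decode.onnx"
# is a placeholder model whose graph takes a single int64 tensor input:
#
#   import numpy as np
#   from onnxruntime_extensions import OrtPyFunction
#
#   decode = OrtPyFunction.from_model("decode.onnx")
#   text = decode(np.asarray([[2, 4, 6]], dtype=np.int64))
#
# A single-operator function can also be built directly from a custom op:
#
#   to_lower = OrtPyFunction.from_customop("StringLower")
#   to_lower(np.array(["ABC"]))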


def ort_inference(model, *args, cpu_only=True, **kwargs):
    """
    Run an ONNX model with ORT where args are inputs and return values are outputs.
    """
    return OrtPyFunction(model, cpu_only=cpu_only)(*args, **kwargs)
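
# Usage sketch (illustrative, not part of the packaged source):
#
#   outputs = ort_inference(model_proto, np.array([1.0, 2.0], dtype=np.float32))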


def optimize_model(model_or_file, output_file):
    sess_options = OrtPyFunction().get_ort_session_options()
    sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    sess_options.optimized_model_filepath = output_file
    _ort.InferenceSession(
        model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options
    )
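
# Usage sketch (illustrative, not part of the packaged source): creating the
# session is enough; ORT writes the basic-level optimized model to the given
# path as a side effect of session construction.
#
#   optimize_model("model.onnx", "model.opt.onnx")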


ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
ONNXRuntimeException = _ort.capi.onnxruntime_pybind11_state.RuntimeException

onnxruntime_extensions/_torch_cvt.py
@@ -0,0 +1,231 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

"""
_torch_cvt.py: Data processing graph converted from PyTorch
"""

import io
import onnx
import torch
import numpy as np

from onnx import numpy_helper

from ._ortapi2 import make_onnx_model
from ._cuops import SingleOpGraph
from ._hf_cvt import HFTokenizerConverter
from .util import remove_unused_initializers, mel_filterbank


class _WhisperHParams:
    SAMPLE_RATE = 16000
    N_FFT = 400
    N_MELS = 80
    HOP_LENGTH = 160
    CHUNK_LENGTH = 30
    N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
    N_FRAMES = N_SAMPLES // HOP_LENGTH


class CustomOpStftNorm(torch.autograd.Function):
    @staticmethod
    def symbolic(g, self, n_fft, hop_length, window):
        t_n_fft = g.op('Constant', value_t=torch.tensor(
            n_fft, dtype=torch.int64))
        t_hop_length = g.op('Constant', value_t=torch.tensor(
            hop_length, dtype=torch.int64))
        t_frame_size = g.op(
            'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
        return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)

    @staticmethod
    def forward(ctx, audio, n_fft, hop_length, window):
        win_length = window.shape[0]
        stft = torch.stft(audio, n_fft, hop_length, win_length, window,
                          center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
        return stft.abs() ** 2
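
# How this works (explanatory note, not part of the packaged source):
# torch.onnx.export invokes symbolic() instead of forward(), so the whole
# power-spectrogram computation is emitted as one custom
# "ai.onnx.contrib::StftNorm" node, while forward() remains the eager
# reference (|STFT|^2) used when the module runs in PyTorch.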


class WhisperPrePipeline(torch.nn.Module):
    def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
                 hop_length=_WhisperHParams.HOP_LENGTH, n_mels=_WhisperHParams.N_MELS,
                 n_samples=_WhisperHParams.N_SAMPLES):
        super().__init__()
        self.n_samples = n_samples
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.window = torch.hann_window(n_fft)
        self.mel_filters = torch.from_numpy(
            mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))

    def forward(self, audio_pcm: torch.Tensor):
        stft_norm = CustomOpStftNorm.apply(audio_pcm,
                                           self.n_fft,
                                           self.hop_length,
                                           self.window)
        magnitudes = stft_norm[:, :, :-1]
        mel_spec = self.mel_filters @ magnitudes
        log_spec = torch.clamp(mel_spec, min=1e-10).log10()
        spec_min = log_spec.max() - 8.0
        log_spec = torch.maximum(log_spec, spec_min)
        spec_shape = log_spec.shape
        padding_spec = torch.ones(spec_shape[0],
                                  spec_shape[1],
                                  self.n_samples // self.hop_length -
                                  spec_shape[2],
                                  dtype=torch.float)
        padding_spec *= spec_min
        log_spec = torch.cat((log_spec, padding_spec), dim=2)
        log_spec = (log_spec + 4.0) / 4.0
        return log_spec
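
# Explanatory note (not part of the packaged source): this reproduces
# Whisper's log-mel front end. The power spectrogram is projected through the
# mel filterbank, clamped to 1e-10 before log10 to avoid log(0), floored to
# 8 (i.e. 80 dB) below the peak, padded with that floor value out to the full
# 30-second frame count, and finally rescaled by (x + 4) / 4.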


def _to_onnx_stft(onnx_model, n_fft):
    """Convert the custom-op StftNorm to the standard ONNX STFT"""
    node_idx = 0
    new_stft_nodes = []
    stft_norm_node = None
    for node in onnx_model.graph.node:
        if node.op_type == "StftNorm":
            stft_norm_node = node
            break
        node_idx += 1

    if stft_norm_node is None:
        raise RuntimeError("Cannot find the StftNorm node in the graph")

    make_node = onnx.helper.make_node
    replaced_nodes = [
        make_node('Constant', inputs=[], outputs=['const_minus_1_output_0'], name='const_minus_1',
                  value=numpy_helper.from_array(np.array([-1], dtype='int64'))),
        make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
                  value=numpy_helper.from_array(np.array([0,
                                                          n_fft // 2, 0,
                                                          n_fft // 2], dtype='int64'),
                                                name='const_14')),
        make_node('Pad',
                  inputs=[stft_norm_node.input[0], 'const_14_output_0'],
                  outputs=['pad_1_output_0'], mode='reflect'),
        make_node('Unsqueeze',
                  inputs=['pad_1_output_0', 'const_minus_1_output_0'],
                  outputs=['unsqueeze_1_output_0'],
                  name='unsqueeze_1'),
        make_node('STFT',
                  inputs=['unsqueeze_1_output_0', stft_norm_node.input[2],
                          stft_norm_node.input[3], stft_norm_node.input[4]],
                  outputs=['stft_output_0'], name='stft', onesided=1),
        make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
                  perm=[0, 2, 1, 3]),
        make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
                  value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
        make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
                  value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
        make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
                  value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
        make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
                                   'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
                  name='slice_1'),
        make_node('Constant', inputs=[], outputs=[
                  'const0_output_0'], name='const0', value_int=0),
        make_node('Constant', inputs=[], outputs=[
                  'const1_output_0'], name='const1', value_int=1),
        make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
                  name='gather_4', axis=3),
        make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
                  name='gather_5', axis=3),
        make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
                  'mul_output_0'], name='mul0'),
        make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
                  'mul_1_output_0'], name='mul1'),
        make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
                  stft_norm_node.output[0]], name='add0'),
    ]
    new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
    new_stft_nodes.extend(replaced_nodes)
    new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
    del onnx_model.graph.node[:]
    onnx_model.graph.node.extend(new_stft_nodes)
    onnx.checker.check_model(onnx_model)
    return onnx_model
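
# Explanatory note (not part of the packaged source): the replacement subgraph
# rebuilds StftNorm from standard operators: reflect-Pad by n_fft//2 on both
# sides, Unsqueeze, a onesided STFT, then Transpose and Slice to match the
# custom kernel's layout, and Gather/Mul/Add to square and sum the real and
# imaginary components into the power spectrogram.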


def _torch_export(*arg, **kwargs):
    with io.BytesIO() as f:
        torch.onnx.export(*arg, f, **kwargs)
        return onnx.load_from_string(f.getvalue())


class WhisperDataProcGraph:
    def __init__(self, processor, **kwargs):
        self.hf_processor = processor
        _opset = kwargs.pop('opset', 17)
        self.opset_version = _opset if _opset else 17

    def pre_processing(self, **kwargs):
        use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
        use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
        feature_extractor = self.hf_processor.feature_extractor
        whisper_processing = WhisperPrePipeline(
            feature_extractor.sampling_rate,
            feature_extractor.n_fft,
            feature_extractor.hop_length,
            feature_extractor.feature_size,
            feature_extractor.n_samples)

        audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
        model_args = (audio_pcm,)
        pre_model = _torch_export(
            whisper_processing,
            model_args,
            input_names=["audio_pcm"],
            output_names=["log_mel"],
            do_constant_folding=True,
            export_params=True,
            opset_version=self.opset_version,
            dynamic_axes={
                "audio_pcm": {1: "sample_len"},
            }
        )
        if use_onnx_stft:
            pre_model = _to_onnx_stft(pre_model, feature_extractor.n_fft)
            remove_unused_initializers(pre_model.graph)

        pre_full = pre_model
        if use_audio_decoder:
            audecoder_g = SingleOpGraph.build_graph(
                "AudioDecoder",
                downsampling_rate=feature_extractor.sampling_rate,
                stereo_to_mono=1)
            audecoder_m = make_onnx_model(audecoder_g)
            pre_full = onnx.compose.merge_models(
                audecoder_m,
                pre_model,
                io_map=[("floatPCM", "audio_pcm")])

        return pre_full

    def post_processing(self, **kwargs):
        skip_special_tokens = kwargs.get('skip_special_tokens', True)
        g = SingleOpGraph.build_graph(
            "BpeDecoder",
            cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
            skip_special_tokens=skip_special_tokens)

        bpenode = g.node[0]
        bpenode.input[0] = "generated_ids"
        nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
                 bpenode]
        del g.node[:]
        g.node.extend(nodes)

        inputs = [onnx.helper.make_tensor_value_info(
            "sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
        del g.input[:]
        g.input.extend(inputs)
        g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
            onnx.TensorProto.STRING, ['N', 'text']))

        return make_onnx_model(g, opset_version=self.opset_version)
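
# Usage sketch (illustrative, not part of the packaged source); assumes the
# HuggingFace transformers package and the "openai/whisper-base.en" checkpoint:
#
#   from transformers import WhisperProcessor
#
#   processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
#   wdp = WhisperDataProcGraph(processor, opset=17)
#   pre_m = wdp.pre_processing(USE_AUDIO_DECODER=True, USE_ONNX_STFT=True)
#   post_m = wdp.post_processing(skip_special_tokens=True)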

onnxruntime_extensions/cmd.py
@@ -0,0 +1,66 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
###############################################################################

"""
cmd.py: CLI commands for onnxruntime_extensions
"""

import os
import argparse
import onnx
import numpy

from onnx import onnx_pb, save_tensor, numpy_helper
from ._ortapi2 import OrtPyFunction


class ORTExtCommands:
    def __init__(self, model='model.onnx', testdata_dir=None) -> None:
        self._model = model
        self._testdata_dir = testdata_dir

    def run(self, *args):
        """
        Run an ONNX model with the arguments as its inputs
        """
        op_func = OrtPyFunction.from_model(self._model)
        np_args = [numpy.asarray(_x) for _x in args]
        for _idx, _sch in enumerate(op_func.inputs):
            if _sch.type.tensor_type.elem_type == onnx_pb.TensorProto.FLOAT:
                np_args[_idx] = np_args[_idx].astype(numpy.float32)

        print(op_func(*np_args))
        if self._testdata_dir:
            testdir = os.path.expanduser(self._testdata_dir)
            target_dir = os.path.join(testdir, 'test_data_set_0')
            os.makedirs(target_dir, exist_ok=True)
            for _idx, _x in enumerate(np_args):
                fn = os.path.join(target_dir, "input_{}.pb".format(_idx))
                save_tensor(numpy_helper.from_array(_x, op_func.inputs[_idx].name), fn)
            onnx.save_model(op_func.onnx_model, os.path.join(testdir, 'model.onnx'))

    def selfcheck(self, *args):
        print("The extensions loaded, status: OK.")


def main():
    parser = argparse.ArgumentParser(description="ORT Extension commands")
    parser.add_argument("command", choices=["run", "selfcheck"])
    parser.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
    parser.add_argument("--testdata-dir", help="Path to the test data directory")
    parser.add_argument("args", nargs=argparse.REMAINDER, help="Additional arguments")

    args = parser.parse_args()

    ort_commands = ORTExtCommands(model=args.model, testdata_dir=args.testdata_dir)

    if args.command == "run":
        ort_commands.run(*args.args)
    elif args.command == "selfcheck":
        ort_commands.selfcheck(*args.args)


if __name__ == '__main__':
    main()
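
# Example invocations (illustrative, not part of the packaged source);
# "model.onnx" and the numeric inputs are placeholders:
#
#   python -m onnxruntime_extensions.cmd selfcheck
#   python -m onnxruntime_extensions.cmd run --model model.onnx --testdata-dir ./dump 1.0 2.0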