onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
onnxruntime_extensions/_ortapi2.py
@@ -0,0 +1,274 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ ###############################################################################
+
+ """
+ _ortapi2.py: ONNXRuntime-Extensions Python API
+ """
+
+ import copy
+ import numpy as np
+ from ._ocos import default_opset_domain, get_library_path, Opdef
+ from ._cuops import onnx, onnx_proto, SingleOpGraph
+
+ _ort_check_passed = False
+ try:
+     from packaging import version as _ver
+     import onnxruntime as _ort
+
+     if _ver.parse(_ort.__version__) >= _ver.parse("1.10.0"):
+         _ort_check_passed = True
+ except ImportError:
+     pass
+
+ if not _ort_check_passed:
+     raise RuntimeError("please install ONNXRuntime/ONNXRuntime-GPU >= 1.10.0")
+
+
+ def _ensure_opset_domain(model):
+     op_domain_name = default_opset_domain()
+     domain_missing = True
+     for oi_ in model.opset_import:
+         if oi_.domain == op_domain_name:
+             domain_missing = False
+
+     if domain_missing:
+         model.opset_import.extend([onnx.helper.make_operatorsetid(op_domain_name, 1)])
+
+     return model
+
+
+ def hook_model_op(model, node_name, hook_func, input_types):
+     """
+     Add a hook function node in the ONNX Model, which could be used for the model diagnosis.
+     :param model: The ONNX model loaded as ModelProto
+     :param node_name: The node name where the hook will be installed
+     :param hook_func: The hook function, callback on the model inference
+     :param input_types: The input types as a list
+     :return: The ONNX model with the hook installed
+     """
+
+     # onnx.shape_inference is very unstable, useless.
+     # hkd_model = shape_inference.infer_shapes(model)
+     hkd_model = model
+
+     n_idx = 0
+     hnode, nnode = (None, None)
+     nodes = list(hkd_model.graph.node)
+     brkpt_name = node_name + "_hkd"
+     optype_name = "op_{}_{}".format(hook_func.__name__, node_name)
+     for n_ in nodes:
+         if n_.name == node_name:
+             input_names = list(n_.input)
+             brk_output_name = [i_ + "_hkd" for i_ in input_names]
+             hnode = onnx.helper.make_node(
+                 optype_name, n_.input, brk_output_name, name=brkpt_name, domain=default_opset_domain()
+             )
+             nnode = n_
+             del nnode.input[:]
+             nnode.input.extend(brk_output_name)
+             break
+         n_idx += 1
+
+     if hnode is None:
+         raise ValueError("{} is not an operator node name".format(node_name))
+
+     repacked = nodes[:n_idx] + [hnode, nnode] + nodes[n_idx + 1 :]
+     del hkd_model.graph.node[:]
+     hkd_model.graph.node.extend(repacked)
+
+     Opdef.create(hook_func, op_type=optype_name, inputs=input_types, outputs=input_types)
+     return _ensure_opset_domain(hkd_model)
+
+
+ def expand_onnx_inputs(model, target_input, extra_nodes, new_inputs):
+     """
+     Replace the existing inputs of a model with the new inputs, plus some extra nodes
+     :param model: The ONNX model loaded as ModelProto
+     :param target_input: The input name to be replaced
+     :param extra_nodes: The extra nodes to be added
+     :param new_inputs: The new input (type: ValueInfoProto) sequence
+     :return: The ONNX model after modification
+     """
+     graph = model.graph
+     new_inputs = [n for n in graph.input if n.name != target_input] + new_inputs
+     new_nodes = list(model.graph.node) + extra_nodes
+     new_graph = onnx.helper.make_graph(new_nodes, graph.name, new_inputs, list(graph.output), list(graph.initializer))
+
+     new_model = copy.deepcopy(model)
+     new_model.graph.CopyFrom(new_graph)
+
+     return _ensure_opset_domain(new_model)
+
+
+ def get_opset_version_from_ort():
+     _ORT_OPSET_SUPPORT_TABLE = {
+         "1.5": 11,
+         "1.6": 12,
+         "1.7": 13,
+         "1.8": 14,
+         "1.9": 15,
+         "1.10": 15,
+         "1.11": 16,
+         "1.12": 17,
+         "1.13": 17,
+         "1.14": 18,
+         "1.15": 18,
+     }
+
+     ort_ver_string = ".".join(_ort.__version__.split(".")[0:2])
+     max_ver = max(_ORT_OPSET_SUPPORT_TABLE, key=_ORT_OPSET_SUPPORT_TABLE.get)
+     if ort_ver_string > max_ver:
+         ort_ver_string = max_ver
+     return _ORT_OPSET_SUPPORT_TABLE.get(ort_ver_string, 11)
+
+
+ def make_onnx_model(graph, opset_version=0, extra_domain=default_opset_domain(), extra_opset_version=1):
+     if opset_version == 0:
+         opset_version = get_opset_version_from_ort()
+     fn_mm = (
+         onnx.helper.make_model_gen_version if hasattr(onnx.helper, "make_model_gen_version") else onnx.helper.make_model
+     )
+     model = fn_mm(graph, opset_imports=[onnx.helper.make_operatorsetid("ai.onnx", opset_version)])
+     model.opset_import.extend([onnx.helper.make_operatorsetid(extra_domain, extra_opset_version)])
+     return model
+
+
+ class OrtPyFunction:
+     """
+     OrtPyFunction is a convenience class that serves as a wrapper around the ONNXRuntime InferenceSession,
+     equipped with registered onnxruntime-extensions. This allows execution of an ONNX model as if it were a
+     standard Python function. The order of the function arguments correlates directly with
+     the sequence of the input/output in the ONNX graph.
+     """
+
+     def get_ort_session_options(self):
+         so = _ort.SessionOptions()
+         for k, v in self.extra_session_options.items():
+             so.__setattr__(k, v)
+         so.register_custom_ops_library(get_library_path())
+         return so
+
+     def __init__(self, path_or_model=None, cpu_only=None):
+         self._onnx_model = None
+         self.ort_session = None
+         self.default_inputs = {}
+         self.execution_providers = ["CPUExecutionProvider"]
+         if not cpu_only:
+             if _ort.get_device() == "GPU":
+                 self.execution_providers = ["CUDAExecutionProvider"]
+         self.extra_session_options = {}
+         mpath = None
+         if isinstance(path_or_model, str):
+             oxml = onnx.load_model(path_or_model)
+             mpath = path_or_model
+         else:
+             oxml = path_or_model
+         if path_or_model is not None:
+             self._bind(oxml, mpath)
+
+     def create_from_customop(self, op_type, *args, **kwargs):
+         graph = SingleOpGraph.build_graph(op_type, *args, **kwargs)
+         self._bind(make_onnx_model(graph))
+         return self
+
+     def add_default_input(self, **kwargs):
+         inputs = {
+             ky_: val_ if isinstance(val_, (np.ndarray, np.generic)) else np.asarray(list(val_), dtype=np.uint8)
+             for ky_, val_ in kwargs.items()
+         }
+
+         self.default_inputs.update(inputs)
+
+     @property
+     def onnx_model(self):
+         assert self._oxml is not None, "No onnx model attached yet."
+         return self._oxml
+
+     @property
+     def input_names(self):
+         return [vi_.name for vi_ in self.onnx_model.graph.input]
+
+     @property
+     def output_names(self):
+         return [vi_.name for vi_ in self.onnx_model.graph.output]
+
+     def _bind(self, oxml, model_path=None):
+         self.inputs = list(oxml.graph.input)
+         self.outputs = list(oxml.graph.output)
+         self._oxml = oxml
+         if model_path is not None:
+             self.ort_session = _ort.InferenceSession(
+                 model_path, self.get_ort_session_options(), self.execution_providers
+             )
+         return self
+
+     def _ensure_ort_session(self):
+         if self.ort_session is None:
+             sess = _ort.InferenceSession(
+                 self.onnx_model.SerializeToString(), self.get_ort_session_options(), self.execution_providers
+             )
+             self.ort_session = sess
+
+         return self.ort_session
+
+     @staticmethod
+     def _get_kwarg_device(kwargs):
+         cpuonly = kwargs.get("cpu_only", None)
+         if cpuonly is not None:
+             del kwargs["cpu_only"]
+         return cpuonly
+
+     @classmethod
+     def from_customop(cls, op_type, *args, **kwargs):
+         return cls(cpu_only=cls._get_kwarg_device(kwargs)).create_from_customop(op_type, *args, **kwargs)
+
+     @classmethod
+     def from_model(cls, path_or_model, *args, **kwargs):
+         fn = cls(path_or_model, cls._get_kwarg_device(kwargs))
+         return fn
+
+     def _argument_map(self, *args, **kwargs):
+         idx = 0
+         feed = {}
+         for i_ in self.inputs:
+             if i_.name in self.default_inputs:
+                 feed[i_.name] = self.default_inputs[i_.name]
+                 continue
+
+             x = args[idx]
+             ts_x = np.array(x) if isinstance(x, (int, float, bool)) else x
+             # numpy by default is int32 in some platforms, sometimes it is int64.
+             feed[i_.name] = (
+                 ts_x.astype(np.int64) if i_.type.tensor_type.elem_type == onnx_proto.TensorProto.INT64 else ts_x
+             )
+             idx += 1
+
+         feed.update(kwargs)
+         return feed
+
+     def __call__(self, *args, **kwargs):
+         self._ensure_ort_session()
+         outputs = self.ort_session.run(None, self._argument_map(*args, **kwargs))
+         return outputs[0] if len(outputs) == 1 else tuple(outputs)
+
+
+ def ort_inference(model, *args, cpu_only=True, **kwargs):
+     """
+     Run an ONNX model with ORT where args are inputs and return values are outputs.
+     """
+     return OrtPyFunction(model, cpu_only=cpu_only)(*args, **kwargs)
+
+
+ def optimize_model(model_or_file, output_file):
+     sess_options = OrtPyFunction().get_ort_session_options()
+     sess_options.graph_optimization_level = _ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
+     sess_options.optimized_model_filepath = output_file
+     _ort.InferenceSession(
+         model_or_file if isinstance(model_or_file, str) else model_or_file.SerializeToString(), sess_options
+     )
+
+
+ ONNXRuntimeError = _ort.capi.onnxruntime_pybind11_state.Fail
+ ONNXRuntimeException = _ort.capi.onnxruntime_pybind11_state.RuntimeException
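Usage note: `OrtPyFunction` wraps an `InferenceSession` that already has the extensions' custom-op library registered, so an ONNX model file, an in-memory `ModelProto`, or a single custom operator becomes a plain Python callable whose positional arguments follow the graph input order. A minimal sketch, assuming the class is re-exported from the package root (otherwise import it from `onnxruntime_extensions._ortapi2`) and that the `StringUpper` custom op is registered in this build; the model path is hypothetical:

    import numpy as np
    from onnxruntime_extensions import OrtPyFunction

    # Wrap an existing ONNX model file as a callable.
    model_fn = OrtPyFunction.from_model("model.onnx")  # hypothetical path
    outputs = model_fn(np.random.rand(1, 3).astype(np.float32))

    # Build a single-operator graph around a custom op and run it directly.
    upper_fn = OrtPyFunction.from_customop("StringUpper")
    print(upper_fn(np.array(["hello world"])))  # expected: ['HELLO WORLD']

For one-off calls, `ort_inference(model, *inputs)` performs the same wrapping internally and defaults to CPU execution.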
onnxruntime_extensions/_torch_cvt.py
@@ -0,0 +1,231 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ ###############################################################################
+
+ """
+ _torch_cvt.py: Data processing graph converted from PyTorch
+ """
+
+ import io
+ import onnx
+ import torch
+ import numpy as np
+
+ from onnx import numpy_helper
+
+ from ._ortapi2 import make_onnx_model
+ from ._cuops import SingleOpGraph
+ from ._hf_cvt import HFTokenizerConverter
+ from .util import remove_unused_initializers, mel_filterbank
+
+
+ class _WhisperHParams:
+     SAMPLE_RATE = 16000
+     N_FFT = 400
+     N_MELS = 80
+     HOP_LENGTH = 160
+     CHUNK_LENGTH = 30
+     N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE  # 480000 samples in a 30-second chunk
+     N_FRAMES = N_SAMPLES // HOP_LENGTH
+
+
+ class CustomOpStftNorm(torch.autograd.Function):
+     @staticmethod
+     def symbolic(g, self, n_fft, hop_length, window):
+         t_n_fft = g.op('Constant', value_t=torch.tensor(
+             n_fft, dtype=torch.int64))
+         t_hop_length = g.op('Constant', value_t=torch.tensor(
+             hop_length, dtype=torch.int64))
+         t_frame_size = g.op(
+             'Constant', value_t=torch.tensor(n_fft, dtype=torch.int64))
+         return g.op("ai.onnx.contrib::StftNorm", self, t_n_fft, t_hop_length, window, t_frame_size)
+
+     @staticmethod
+     def forward(ctx, audio, n_fft, hop_length, window):
+         win_length = window.shape[0]
+         stft = torch.stft(audio, n_fft, hop_length, win_length, window,
+                           center=True, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
+         return stft.abs() ** 2
+
+
+ class WhisperPrePipeline(torch.nn.Module):
+     def __init__(self, sr=_WhisperHParams.SAMPLE_RATE, n_fft=_WhisperHParams.N_FFT,
+                  hop_length=_WhisperHParams.HOP_LENGTH, n_mels=_WhisperHParams.N_MELS,
+                  n_samples=_WhisperHParams.N_SAMPLES):
+         super().__init__()
+         self.n_samples = n_samples
+         self.hop_length = hop_length
+         self.n_fft = n_fft
+         self.window = torch.hann_window(n_fft)
+         self.mel_filters = torch.from_numpy(
+             mel_filterbank(sr=sr, n_fft=n_fft, n_mels=n_mels))
+
+     def forward(self, audio_pcm: torch.Tensor):
+         stft_norm = CustomOpStftNorm.apply(audio_pcm,
+                                            self.n_fft,
+                                            self.hop_length,
+                                            self.window)
+         magnitudes = stft_norm[:, :, :-1]
+         mel_spec = self.mel_filters @ magnitudes
+         log_spec = torch.clamp(mel_spec, min=1e-10).log10()
+         spec_min = log_spec.max() - 8.0
+         log_spec = torch.maximum(log_spec, spec_min)
+         spec_shape = log_spec.shape
+         padding_spec = torch.ones(spec_shape[0],
+                                   spec_shape[1],
+                                   self.n_samples // self.hop_length -
+                                   spec_shape[2],
+                                   dtype=torch.float)
+         padding_spec *= spec_min
+         log_spec = torch.cat((log_spec, padding_spec), dim=2)
+         log_spec = (log_spec + 4.0) / 4.0
+         return log_spec
+
+
+ def _to_onnx_stft(onnx_model, n_fft):
+     """Convert custom-op STFT-Norm to ONNX STFT"""
+     node_idx = 0
+     new_stft_nodes = []
+     stft_norm_node = None
+     for node in onnx_model.graph.node:
+         if node.op_type == "StftNorm":
+             stft_norm_node = node
+             break
+         node_idx += 1
+
+     if stft_norm_node is None:
+         raise RuntimeError("Cannot find STFTNorm node in the graph")
+
+     make_node = onnx.helper.make_node
+     replaced_nodes = [
+         make_node('Constant', inputs=[], outputs=['const_minus_1_output_0'], name='const_minus_1',
+                   value=numpy_helper.from_array(np.array([-1], dtype='int64'))),
+         make_node('Constant', inputs=[], outputs=['const_14_output_0'], name='const_14',
+                   value=numpy_helper.from_array(np.array([0,
+                                                           n_fft // 2, 0,
+                                                           n_fft // 2], dtype='int64'),
+                                                 name='const_14')),
+         make_node('Pad',
+                   inputs=[stft_norm_node.input[0], 'const_14_output_0'],
+                   outputs=['pad_1_output_0'], mode='reflect'),
+         make_node('Unsqueeze',
+                   inputs=['pad_1_output_0', 'const_minus_1_output_0'],
+                   outputs=['unsqueeze_1_output_0'],
+                   name='unsqueeze_1'),
+         make_node('STFT',
+                   inputs=['unsqueeze_1_output_0', stft_norm_node.input[2],
+                           stft_norm_node.input[3], stft_norm_node.input[4]],
+                   outputs=['stft_output_0'], name='stft', onesided=1),
+         make_node('Transpose', inputs=['stft_output_0'], outputs=['transpose_1_output_0'], name='transpose_1',
+                   perm=[0, 2, 1, 3]),
+         make_node('Constant', inputs=[], outputs=['const_17_output_0'], name='const_17',
+                   value=numpy_helper.from_array(np.array([2], dtype='int64'), name='')),
+         make_node('Constant', inputs=[], outputs=['const_18_output_0'], name='const_18',
+                   value=numpy_helper.from_array(np.array([0], dtype='int64'), name='')),
+         make_node('Constant', inputs=[], outputs=['const_20_output_0'], name='const_20',
+                   value=numpy_helper.from_array(np.array([1], dtype='int64'), name='')),
+         make_node('Slice', inputs=['transpose_1_output_0', 'const_18_output_0', 'const_minus_1_output_0',
+                                    'const_17_output_0', 'const_20_output_0'], outputs=['slice_1_output_0'],
+                   name='slice_1'),
+         make_node('Constant', inputs=[], outputs=[
+                   'const0_output_0'], name='const0', value_int=0),
+         make_node('Constant', inputs=[], outputs=[
+                   'const1_output_0'], name='const1', value_int=1),
+         make_node('Gather', inputs=['slice_1_output_0', 'const0_output_0'], outputs=['gather_4_output_0'],
+                   name='gather_4', axis=3),
+         make_node('Gather', inputs=['slice_1_output_0', 'const1_output_0'], outputs=['gather_5_output_0'],
+                   name='gather_5', axis=3),
+         make_node('Mul', inputs=['gather_4_output_0', 'gather_4_output_0'], outputs=[
+                   'mul_output_0'], name='mul0'),
+         make_node('Mul', inputs=['gather_5_output_0', 'gather_5_output_0'], outputs=[
+                   'mul_1_output_0'], name='mul1'),
+         make_node('Add', inputs=['mul_output_0', 'mul_1_output_0'], outputs=[
+                   stft_norm_node.output[0]], name='add0'),
+     ]
+     new_stft_nodes.extend(onnx_model.graph.node[:node_idx])
+     new_stft_nodes.extend(replaced_nodes)
+     new_stft_nodes.extend(onnx_model.graph.node[node_idx + 1:])
+     del onnx_model.graph.node[:]
+     onnx_model.graph.node.extend(new_stft_nodes)
+     onnx.checker.check_model(onnx_model)
+     return onnx_model
+
+
+ def _torch_export(*arg, **kwargs):
+     with io.BytesIO() as f:
+         torch.onnx.export(*arg, f, **kwargs)
+         return onnx.load_from_string(f.getvalue())
+
+
+ class WhisperDataProcGraph:
+     def __init__(self, processor, **kwargs):
+         self.hf_processor = processor
+         _opset = kwargs.pop('opset', 17)
+         self.opset_version = _opset if _opset else 17
+
+     def pre_processing(self, **kwargs):
+         use_audio_decoder = kwargs.pop('USE_AUDIO_DECODER', True)
+         use_onnx_stft = kwargs.pop('USE_ONNX_STFT', True)
+         feature_extractor = self.hf_processor.feature_extractor
+         whisper_processing = WhisperPrePipeline(
+             feature_extractor.sampling_rate,
+             feature_extractor.n_fft,
+             feature_extractor.hop_length,
+             feature_extractor.feature_size,
+             feature_extractor.n_samples)
+
+         audio_pcm = torch.rand((1, 32000), dtype=torch.float32)
+         model_args = (audio_pcm,)
+         pre_model = _torch_export(
+             whisper_processing,
+             model_args,
+             input_names=["audio_pcm"],
+             output_names=["log_mel"],
+             do_constant_folding=True,
+             export_params=True,
+             opset_version=self.opset_version,
+             dynamic_axes={
+                 "audio_pcm": {1: "sample_len"},
+             }
+         )
+         if use_onnx_stft:
+             pre_model = _to_onnx_stft(pre_model, feature_extractor.n_fft)
+             remove_unused_initializers(pre_model.graph)
+
+         pre_full = pre_model
+         if use_audio_decoder:
+             audecoder_g = SingleOpGraph.build_graph(
+                 "AudioDecoder",
+                 downsampling_rate=feature_extractor.sampling_rate,
+                 stereo_to_mono=1)
+             audecoder_m = make_onnx_model(audecoder_g)
+             pre_full = onnx.compose.merge_models(
+                 audecoder_m,
+                 pre_model,
+                 io_map=[("floatPCM", "audio_pcm")])
+
+         return pre_full
+
+     def post_processing(self, **kwargs):
+         skip_special_tokens = kwargs.get('skip_special_tokens', True)
+         g = SingleOpGraph.build_graph(
+             "BpeDecoder",
+             cvt=HFTokenizerConverter(self.hf_processor.tokenizer).bpe_decoder,
+             skip_special_tokens=skip_special_tokens)
+
+         bpenode = g.node[0]
+         bpenode.input[0] = "generated_ids"
+         nodes = [onnx.helper.make_node('Cast', ['sequences'], ["generated_ids"], to=onnx.TensorProto.INT64),
+                  bpenode]
+         del g.node[:]
+         g.node.extend(nodes)
+
+         inputs = [onnx.helper.make_tensor_value_info(
+             "sequences", onnx.TensorProto.INT32, ['N', 'seq_len', 'ids'])]
+         del g.input[:]
+         g.input.extend(inputs)
+         g.output[0].type.CopyFrom(onnx.helper.make_tensor_type_proto(
+             onnx.TensorProto.STRING, ['N', 'text']))
+
+         return make_onnx_model(g, opset_version=self.opset_version)
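Usage note: `WhisperDataProcGraph` exports the PyTorch log-mel pipeline to ONNX, optionally rewrites the custom STFT into standard ONNX `STFT` nodes, prepends an `AudioDecoder` custom op, and builds a `BpeDecoder` graph for the output side. A sketch of the intended flow, assuming a Hugging Face `WhisperProcessor` is available (`transformers` is not a declared dependency of this wheel, and the checkpoint name is illustrative):

    from transformers import WhisperProcessor
    from onnxruntime_extensions._torch_cvt import WhisperDataProcGraph

    processor = WhisperProcessor.from_pretrained("openai/whisper-tiny.en")  # illustrative checkpoint
    cvt = WhisperDataProcGraph(processor, opset=17)

    # Audio bytes -> log-mel spectrogram, with the custom StftNorm rewritten to ONNX STFT.
    pre_m = cvt.pre_processing(USE_AUDIO_DECODER=True, USE_ONNX_STFT=True)
    # Generated token ids ('sequences', int32) -> decoded text.
    post_m = cvt.post_processing(skip_special_tokens=True)

The resulting pre/post graphs can then be composed with the Whisper model itself via `onnx.compose.merge_models`, the same mechanism `pre_processing` uses to attach the `AudioDecoder` stage.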
onnxruntime_extensions/_version.py
@@ -0,0 +1,2 @@
+ # Generated by setup.py, DON'T MANUALLY UPDATE IT!
+ __version__ = "0.14.0"
onnxruntime_extensions/cmd.py
@@ -0,0 +1,66 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ ###############################################################################
+
+ """
+ cmd.py: cli commands for onnxruntime_extensions
+ """
+
+ import os
+ import argparse
+ import onnx
+ import numpy
+
+ from onnx import onnx_pb, save_tensor, numpy_helper
+ from ._ortapi2 import OrtPyFunction
+
+
+ class ORTExtCommands:
+     def __init__(self, model='model.onnx', testdata_dir=None) -> None:
+         self._model = model
+         self._testdata_dir = testdata_dir
+
+     def run(self, *args):
+         """
+         Run an onnx model with the arguments as its inputs
+         """
+         op_func = OrtPyFunction.from_model(self._model)
+         np_args = [numpy.asarray(_x) for _x in args]
+         for _idx, _sch in enumerate(op_func.inputs):
+             if _sch.type.tensor_type.elem_type == onnx_pb.TensorProto.FLOAT:
+                 np_args[_idx] = np_args[_idx].astype(numpy.float32)
+
+         print(op_func(*np_args))
+         if self._testdata_dir:
+             testdir = os.path.expanduser(self._testdata_dir)
+             target_dir = os.path.join(testdir, 'test_data_set_0')
+             os.makedirs(target_dir, exist_ok=True)
+             for _idx, _x in enumerate(np_args):
+                 fn = os.path.join(target_dir, "input_{}.pb".format(_idx))
+                 save_tensor(numpy_helper.from_array(_x, op_func.inputs[_idx].name), fn)
+             onnx.save_model(op_func.onnx_model, os.path.join(testdir, 'model.onnx'))
+
+     def selfcheck(self, *args):
+         print("The extensions loaded, status: OK.")
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="ORT Extension commands")
+     parser.add_argument("command", choices=["run", "selfcheck"])
+     parser.add_argument("--model", default="model.onnx", help="Path to the ONNX model file")
+     parser.add_argument("--testdata-dir", help="Path to the test data directory")
+     parser.add_argument("args", nargs=argparse.REMAINDER, help="Additional arguments")
+
+     args = parser.parse_args()
+
+     ort_commands = ORTExtCommands(model=args.model, testdata_dir=args.testdata_dir)
+
+     if args.command == "run":
+         ort_commands.run(*args.args)
+     elif args.command == "selfcheck":
+         ort_commands.selfcheck(*args.args)
+
+
+ if __name__ == '__main__':
+     main()
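Usage note: because the module guards `main()` behind `__main__`, the commands can be driven as `python -m onnxruntime_extensions.cmd run --model model.onnx 1 2 3` or `python -m onnxruntime_extensions.cmd selfcheck` (assuming no dedicated console-script entry point is installed), or programmatically; a sketch with hypothetical paths:

    from onnxruntime_extensions.cmd import ORTExtCommands

    cmds = ORTExtCommands(model="model.onnx", testdata_dir="~/testdata")  # hypothetical paths
    cmds.run(1.0, 2.0)  # prints outputs; dumps inputs and model under ~/testdata for test data sets
    cmds.selfcheck()    # "The extensions loaded, status: OK."

Note that `run` coerces each argument with `numpy.asarray` and casts it to float32 only when the corresponding graph input is declared FLOAT, so how scalar arguments map onto the model depends on its input signature.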