onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime_extensions/__init__.py +82 -0
- onnxruntime_extensions/_cuops.py +564 -0
- onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
- onnxruntime_extensions/_extensions_pydll.pyi +45 -0
- onnxruntime_extensions/_hf_cvt.py +331 -0
- onnxruntime_extensions/_ocos.py +133 -0
- onnxruntime_extensions/_ortapi2.py +274 -0
- onnxruntime_extensions/_torch_cvt.py +231 -0
- onnxruntime_extensions/_version.py +2 -0
- onnxruntime_extensions/cmd.py +66 -0
- onnxruntime_extensions/cvt.py +306 -0
- onnxruntime_extensions/onnxprocess/__init__.py +12 -0
- onnxruntime_extensions/onnxprocess/_builder.py +53 -0
- onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
- onnxruntime_extensions/onnxprocess/_session.py +355 -0
- onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
- onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
- onnxruntime_extensions/pnp/__init__.py +13 -0
- onnxruntime_extensions/pnp/_base.py +124 -0
- onnxruntime_extensions/pnp/_imagenet.py +65 -0
- onnxruntime_extensions/pnp/_nlp.py +148 -0
- onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
- onnxruntime_extensions/pnp/_torchext.py +310 -0
- onnxruntime_extensions/pnp/_unifier.py +45 -0
- onnxruntime_extensions/pnp/_utils.py +302 -0
- onnxruntime_extensions/pp_api.py +83 -0
- onnxruntime_extensions/tools/__init__.py +0 -0
- onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
- onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
- onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
- onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
- onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
- onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
- onnxruntime_extensions/util.py +186 -0
- onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
- onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
- onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
- onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
- onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
onnxruntime_extensions/pnp/_torchext.py
@@ -0,0 +1,310 @@
+import onnx
+import torch
+import numpy as np
+from typing import Any
+from onnx import helper
+from onnx import onnx_pb as onnx_proto
+from distutils.version import LooseVersion
+from torch.onnx import register_custom_op_symbolic
+
+from ._utils import ONNXModelUtils
+from ._base import CustomFunction, ProcessingTracedModule, is_processing_module
+from ._onnx_ops import ox as _ox, schema as _schema
+from ._onnx_ops import ONNXElementContainer, make_model_ex
+from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
+
+
+def _is_numpy_object(x):
+    return isinstance(x, (np.ndarray, np.generic))
+
+
+def _is_numpy_string_type(arr):
+    return arr.dtype.kind in {'U', 'S'}
+
+
+def _is_string_type(x):
+    if isinstance(x, list):
+        return any(_is_string_type(e) for e in x)
+    elif isinstance(x, torch.Tensor):
+        return False
+    elif not _is_numpy_object(x):
+        x = np.array(x)
+    return _is_numpy_string_type(x)
+
+
+def _to_onnx_type(dtype):
+    ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
+               torch.float32: onnx_proto.TensorProto.FLOAT,
+               torch.float64: onnx_proto.TensorProto.DOUBLE,
+               torch.long: onnx_proto.TensorProto.INT64,
+               torch.int32: onnx_proto.TensorProto.INT32}
+    # ...
+    return ty_dict.get(dtype, onnx_proto.TensorProto.STRING)
+
+
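A minimal behavioral sketch of the helpers above, assuming they are imported from this module (both are private helpers): anything that coerces to a numpy unicode/bytes array counts as string-typed, torch tensors never do, and any dtype missing from the table falls back to STRING.

    import torch
    from onnx import onnx_pb as onnx_proto
    from onnxruntime_extensions.pnp._torchext import _is_string_type, _to_onnx_type

    assert _is_string_type(['a', 'b'])                 # python strings coerce to a numpy 'U' array
    assert not _is_string_type(torch.tensor([1.0]))    # torch tensors are never string-typed
    assert _to_onnx_type(torch.float32) == onnx_proto.TensorProto.FLOAT
    assert _to_onnx_type(torch.int8) == onnx_proto.TensorProto.STRING   # unmapped dtype -> STRING fallback
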
+class OnnxOpFunction(CustomFunction):
+    @classmethod
+    def get_next_id_name(cls, name_base):
+        name = 'cls' if name_base is None else name_base
+        _cid = getattr(cls, '_cid', 1)
+        cls._cid = _cid + 1
+        return "{}_{}".format(name, _cid)
+
+    @staticmethod
+    def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+        pass
+
+    @staticmethod
+    def backward(ctx: Any, *grad_outputs: Any) -> Any:
+        return grad_outputs
+
+    @classmethod
+    def build_model(cls, opset_version, *args):
+        # build the one node graph
+        if isinstance(args[0], list):
+            args = [np.asarray(_i) for _i in args]
+        ec = ONNXElementContainer(get_opset_version_from_ort() if opset_version is None else opset_version)
+        attrs = cls.attrs
+        vi_inputs = [helper.make_tensor_value_info(
+            'it_' + str(id(_arg)), _to_onnx_type(_arg.dtype), list(_arg.shape))
+            for _arg in args]
+        inputs = [_vi.name for _vi in vi_inputs]
+        if hasattr(cls.opb_func, 'outputs') and len(cls.opb_func.outputs) > 0:
+            vi_outputs = [helper.make_tensor_value_info(
+                cls.get_next_id_name('ot'), *_schm) for _schm in cls.opb_func.outputs]
+        else:
+            vi_outputs = [helper.make_tensor_value_info(
+                cls.get_next_id_name('ot'), onnx_proto.TensorProto.FLOAT, []
+            )]
+        outputs = [_vi.name for _vi in vi_outputs]
+        # build the node
+        opfunc = cls.opb_func
+        opfunc(inputs, outputs, ec, None, **attrs)
+        g = helper.make_graph(ec.nodes, cls.get_next_id_name('g'), vi_inputs, vi_outputs)
+        m = make_model_ex(g, ec.node_domain_version_pair_sets, ec.target_opset)
+        return m
+
+    @classmethod
+    @torch.jit.unused
+    def _onnx_call(cls, ctx, *args) -> Any:
+        m = cls.build_model(None, *args)
+        try:
+            f = OrtPyFunction.from_model(m)
+            result = f(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args))
+        except Exception as e:
+            onnx.save_model(m, '_temp_debugging.onnx')
+            raise e
+
+        results = result if isinstance(result, tuple) else [result]
+        return tuple([torch.from_numpy(_o) for _o in results]) if len(results) > 1 else torch.from_numpy(results[0])
+
+    @classmethod
+    def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        return cls._onnx_call(ctx, *args, **kwargs)
+
+    @classmethod
+    def symbolic(cls, g, *args):
+        return g.op(cls.op_type, *args)
+
+
+def create_op_function(op_type: str, func, **attrs):
+    if _ox.is_raw(func):
+        func = _schema(func.__func__)
+    cls = type(_ox.get_unique_operator_type_name(op_type), (OnnxOpFunction,),
+               dict(
+                   op_type=op_type,
+                   opb_func=func,
+                   attrs=attrs
+               ))
+    return cls.apply  # noqa
+
+
+onnx_pad = create_op_function('Pad', _ox.pad)
+onnx_where = create_op_function('Where', _ox.where)
+onnx_greater = create_op_function('Greater', _ox.greater)
+
+
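`create_op_function` returns the bound `apply` of a freshly generated `OnnxOpFunction` subclass, so `onnx_pad`, `onnx_where`, and `onnx_greater` act like autograd functions whose forward pass builds a one-node ONNX graph and executes it through `OrtPyFunction`. A minimal usage sketch, assuming the op schemas in `_onnx_ops.py` (not shown in this diff) declare the output types:

    import torch
    from onnxruntime_extensions.pnp._torchext import onnx_greater, onnx_where

    x = torch.tensor([1.0, -2.0, 3.0])
    y = torch.zeros(3)
    mask = onnx_greater(x, y)      # executes a single 'Greater' node via onnxruntime
    z = onnx_where(mask, x, y)     # executes a single 'Where' node; expected tensor([1., 0., 3.])
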
+class _OnnxModelFunction:
+    id_object_map = {}  # cannot use the string directly since jit.script doesn't support the data type
+    id_function_map = {}
+    str_model_function_id = '_model_function_id'
+    str_model_id = '_model_id'
+    str_model_attached = '_model_attached'
+
+
+@torch.jit.ignore
+def _invoke_onnx_model(model_id: int, *args, **kwargs):
+    func = _OnnxModelFunction.id_function_map.get(model_id, None)
+    if not func:
+        model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
+        if model_or_path is None:
+            raise ValueError("cannot find id={} registered!".format(model_id))
+        func = OrtPyFunction.from_model(model_or_path)
+        _OnnxModelFunction.id_function_map[model_id] = func
+    results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
+    return tuple(
+        [torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)
+
+
+@torch.jit.ignore
+def invoke_onnx_model1(model_id: int, arg0):
+    return _invoke_onnx_model(model_id, arg0)
+
+
+@torch.jit.ignore
+def invoke_onnx_model2(model_id: int, arg0, arg1):
+    return _invoke_onnx_model(model_id, arg0, arg1)
+
+
+@torch.jit.ignore
+def invoke_onnx_model3(model_id: int, arg0, arg1, arg2):
+    return _invoke_onnx_model(model_id, arg0, arg1, arg2)
+
+
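`_invoke_onnx_model` lazily compiles the registered model (a `ModelProto` or a file path) into an `OrtPyFunction` on first use and caches it in `id_function_map`; the fixed-arity `invoke_onnx_model1/2/3` wrappers exist because `torch.jit` cannot script a `*args` call site. A minimal sketch of the registry round trip with a trivial one-node model:

    import torch
    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp import _torchext as te

    node = helper.make_node('Neg', ['x'], ['y'], name='neg0')
    graph = helper.make_graph([node], 'neg_g',
                              [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])],
                              [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])])
    m = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

    te._OnnxModelFunction.id_object_map[id(m)] = m        # what create_model_function (below) does
    print(te.invoke_onnx_model1(id(m), torch.tensor([1.0, -2.0])))   # tensor([-1., 2.])
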
+class _OnnxTracedFunction(CustomFunction):
+    @classmethod
+    def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
+        return _invoke_onnx_model(args[0].item(), *args[1:], **kwargs)
+
+    @classmethod
+    def symbolic(cls, g, *args):
+        ret = g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+        model_id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0])  # noqa
+        if not model_id:
+            return ret
+
+        func = _OnnxModelFunction.id_function_map.get(model_id.item(), None)
+        if not func or len(func.outputs) <= 1:
+            return ret
+
+        outputs = [ret]
+        for _ in range(len(func.outputs) - 1):
+            outputs.append(ret.node().addOutput())
+
+        return tuple(outputs)
+
+
+def create_model_function(model_or_path):
+    _id = id(model_or_path)
+    assert _id != 0, "internal error: the id of a Python object is 0."
+    _OnnxModelFunction.id_object_map[_id] = model_or_path
+    return _id
+
+
+def get_id_models():
+    return _OnnxModelFunction.id_object_map
+
+
+class OnnxTracedModelFunction:
+    def __init__(self, onnx_model):
+        self.func_id = create_model_function(onnx_model)
+
+    def __call__(self, *args, **kwargs):
+        return _OnnxTracedFunction.apply(torch.tensor(self.func_id), *args, **kwargs)
+
+
+class _OnnxModelModule(torch.nn.Module):
+    def __init__(self, mdl):
+        super(_OnnxModelModule, self).__init__()
+        self.function = OnnxTracedModelFunction(mdl)
+
+    def forward(self, *args):
+        return self.function(*args)
+
+
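`_OnnxModelModule` is the glue that lets a plain `onnx.ModelProto` sit inside a `torch.nn.Module` pipeline: construction registers the model and remembers its id, eager calls route through onnxruntime, and during `torch.onnx.export` the `symbolic` above emits an `ai.onnx.contrib::_ModelFunctionCall` placeholder keyed by that id. A short sketch:

    import torch
    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp._torchext import _OnnxModelModule

    node = helper.make_node('Abs', ['x'], ['y'], name='abs0')
    graph = helper.make_graph([node], 'abs_g',
                              [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])],
                              [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])])
    m = helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

    wrapped = _OnnxModelModule(m)                  # registers m and stores id(m)
    print(wrapped(torch.tensor([-3.0, 4.0])))      # tensor([3., 4.]), computed by onnxruntime
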
+def _symbolic_pythonop(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
+    name = kwargs["name"]
+    if name.startswith(invoke_onnx_model1.__name__[:-1]):
+        # NB: if you want to get the value of the first argument, i.e. the model id,
+        # you can get it by torch.onnx.symbolic_helper._maybe_get_scalar(args[0]).item()
+        ret = g.op("ai.onnx.contrib::_ModelFunctionCall", *args)
+    else:
+        # Logs a warning and returns None
+        import warnings
+        warnings.warn("prim::PythonOp: unknown node kind: " + name)
+        return None
+    # Copy type and shape from original node.
+    ret.setType(n.output().type())
+    return ret
+
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
+    register_custom_op_symbolic("prim::PythonOp", _symbolic_pythonop, 1)
+
+
+class SequentialProcessingModule(ProcessingTracedModule):
+    def __init__(self, *models):
+        super(SequentialProcessingModule, self).__init__()
+        self.model_list = torch.nn.ModuleList()
+        for mdl_ in models:
+            if isinstance(mdl_, onnx.ModelProto):
+                self.model_list.append(_OnnxModelModule(mdl_))
+            elif is_processing_module(mdl_):
+                self.model_list.append(mdl_)
+            else:
+                assert callable(mdl_), "the model type is not recognizable."
+                self.model_list.append(ProcessingTracedModule(mdl_))
+
+    def forward(self, *args):
+        outputs = args
+        with torch.no_grad():
+            for idx_, mdl_ in enumerate(self.model_list):
+                if not isinstance(outputs, tuple):
+                    outputs = (outputs,)
+                outputs = mdl_(*outputs)
+
+        return outputs
+
+    def export(self, *args, **kwargs):
+        prefix_m = None
+        core_m = self
+        raw_input_flag = any(_is_string_type(x_) for x_ in args)
+        if raw_input_flag:
+            # NB: torch.onnx.export doesn't support exporting a module accepting string type input,
+            # so, in this case, the module will be separated into two parts to use the customized export.
+            m0 = self.model_list[0]
+            new_args = m0(*args)
+            if not isinstance(new_args, tuple):
+                new_args = (new_args, )
+            prefix_m = m0.export(*args, **kwargs)
+            args = new_args
+            core_m = SequentialProcessingModule(*self.model_list[1:])
+        if prefix_m is None:
+            return super().export(*args, **kwargs)
+        else:
+            oxml = core_m.export(*args, **kwargs)
+            model = ONNXModelUtils.join_models(prefix_m, oxml)
+
+            # Rename the input/output node names if the user has provided any substitutions!
+            # Ref: https://github.com/onnx/onnx/issues/2052
+            # Known issue: This logic doesn't deal with subgraphs.
+            if kwargs.get('input_names') or kwargs.get('output_names'):
+                swaps = {}
+                if kwargs.get('input_names'):
+                    assert len(model.graph.input) == len(kwargs['input_names']), \
+                        "Expecting {} input names but got {}".format(
+                            len(model.graph.input), len(kwargs['input_names']))
+                    for n, new_name in zip(model.graph.input, kwargs['input_names']):
+                        swaps[n.name] = new_name
+                        n.name = new_name
+
+                if kwargs.get('output_names'):
+                    assert len(model.graph.output) == len(kwargs['output_names']), \
+                        "Expecting {} output names but got {}".format(
+                            len(model.graph.output), len(kwargs['output_names']))
+                    for n, new_name in zip(model.graph.output, kwargs['output_names']):
+                        swaps[n.name] = new_name
+                        n.name = new_name
+
+                if swaps:
+                    for n in model.graph.node:
+                        for j in range(len(n.input)):
+                            n.input[j] = swaps.get(n.input[j], n.input[j])
+
+                        for j in range(len(n.output)):
+                            n.output[j] = swaps.get(n.output[j], n.output[j])
+
+                    for n in model.graph.initializer:
+                        n.name = swaps.get(n.name, n.name)
+
+            return model
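`SequentialProcessingModule` normalizes every stage (a `ModelProto`, an existing processing module, or a bare callable) into a module and threads tuples through them; its `export` splits the pipeline when the sample inputs are strings, exports the string-consuming head separately, and stitches the two graphs back together with `join_models`. A minimal two-stage sketch with toy models:

    import torch
    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp._torchext import SequentialProcessingModule

    def _toy(op):
        node = helper.make_node(op, ['x'], ['y'], name=op.lower() + '0')
        graph = helper.make_graph([node], op.lower() + '_g',
                                  [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])],
                                  [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])])
        return helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

    pipeline = SequentialProcessingModule(_toy('Abs'), _toy('Neg'))
    print(pipeline(torch.tensor([1.0, -2.0])))     # tensor([-1., -2.]): Neg(Abs(x))
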
onnxruntime_extensions/pnp/_unifier.py
@@ -0,0 +1,45 @@
+import onnx
+
+from .._ortapi2 import get_opset_version_from_ort
+from ._utils import ONNXModelUtils
+from ._base import is_processing_module
+from ._torchext import get_id_models, SequentialProcessingModule
+
+
+def export(m, *args,
+           opset_version=0,
+           output_path=None,
+           export_params=True,
+           verbose=False,
+           input_names=None,
+           output_names=None,
+           operator_export_type=None,
+           do_constant_folding=True,
+           dynamic_axes=None,
+           keep_initializers_as_inputs=None,
+           custom_opsets=None,
+           io_mapping=None):
+    """
+    export all models and modules into a merged ONNX model.
+    """
+    if opset_version == 0:
+        opset_version = get_opset_version_from_ort()
+
+    if not is_processing_module(m):
+        m = SequentialProcessingModule(m)
+
+    model = m.export(*args, opset_version=opset_version,
+                     output_path=output_path,
+                     export_params=export_params,
+                     verbose=verbose,
+                     input_names=input_names,
+                     output_names=output_names,
+                     operator_export_type=operator_export_type,
+                     do_constant_folding=do_constant_folding,
+                     dynamic_axes=dynamic_axes,
+                     keep_initializers_as_inputs=keep_initializers_as_inputs,
+                     custom_opsets=custom_opsets)
+    full_m = ONNXModelUtils.unfold_model(model, get_id_models(), io_mapping)
+    if output_path is not None:
+        onnx.save_model(full_m, output_path)
+    return full_m
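`export` is the public entry point that ties the pieces together: it resolves the opset from the installed onnxruntime when none is given, wraps a plain module in `SequentialProcessingModule`, runs the module's own `export` with the familiar `torch.onnx.export` keyword set, then `unfold_model` inlines every `_ModelFunctionCall` placeholder using the id-to-model table from `get_id_models()`. A hypothetical call, assuming `export` is re-exported by `onnxruntime_extensions.pnp` (its `__init__.py` is listed above but not shown):

    import torch
    from onnxruntime_extensions import pnp

    head = torch.nn.Linear(4, 2)                      # any traceable torch module
    merged = pnp.export(head, torch.rand(1, 4),
                        opset_version=17,
                        input_names=['x'], output_names=['logits'],
                        output_path='pipeline.onnx')  # also saved to disk because output_path is set
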
onnxruntime_extensions/pnp/_utils.py
@@ -0,0 +1,302 @@
+import copy
+import onnx
+from onnx import helper, numpy_helper
+from collections import namedtuple
+
+
+class _Container:
+    def __init__(self):
+        self.parent = None
+        self.initializer = []
+        self.value_info = []
+        self.nodes = []
+        self.node_domain_version_pair_sets = {}
+
+    def add_model(self, oxml):
+        self.initializer.extend(oxml.graph.initializer)
+        self.value_info.extend(oxml.graph.value_info)
+        self.nodes.extend(oxml.graph.node)
+        self.node_domain_version_pair_sets.update(
+            [(opset_.domain, opset_.version) for opset_ in oxml.opset_import])
+        return self
+
+
+class ONNXModelUtils:
+    @staticmethod
+    def merge_name(prefix, name):
+        return "{}_{}".format(prefix, name)
+
+    @staticmethod
+    def _rename_iter(iterables, prefix_name, inplace=False):
+        new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
+        for iz_ in new_iz:
+            iz_.name = ONNXModelUtils.merge_name(prefix_name, iz_.name)
+        return new_iz
+
+    @classmethod
+    def _rename_graph(cls, graph, prefix, graph_or_container):
+        def io_rename(node, prefix_name, idx):
+            new_node = copy.deepcopy(node)
+            if not node.name:
+                new_node.name = cls.merge_name(prefix_name, "op{}".format(idx))
+            else:
+                new_node.name = cls.merge_name(prefix_name, node.name)
+
+            del new_node.input[:]
+            new_node.input.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
+            del new_node.output[:]
+            new_node.output.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
+            return new_node
+
+        assert prefix is not None, 'The graph prefix could not be None'
+        graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
+        graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
+        return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))
+
+    @classmethod
+    def _process_node_body(cls, node, prefix):
+        if all(attr.name != 'body' for attr in node.attribute):
+            return node
+
+        def _process_attr(attr, prefix_name):
+            if attr.name == 'body':
+                new_attr = copy.deepcopy(attr)
+                del new_attr.g.value_info[:]
+                del new_attr.g.node[:]
+                new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
+                cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
+                cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
+                return new_attr
+            else:
+                return attr
+
+        attr_list = list(_process_attr(attr_, prefix) for attr_ in node.attribute)
+        del node.attribute[:]
+        node.attribute.extend(attr_list)
+        return node
+
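`_rename_graph` is the namespacing primitive the rest of the class relies on: initializers and value_info move into the target container under a `prefix_` name, and every node plus its edges is rewritten the same way (empty names, which mark optional inputs, are preserved). A small sketch:

    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp._utils import ONNXModelUtils, _Container

    node = helper.make_node('Abs', ['x'], ['y'], name='abs0')
    graph = helper.make_graph([node], 'g0',
                              [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])],
                              [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])])

    renamed = ONNXModelUtils._rename_graph(graph, 'g1', _Container())
    print(renamed[0].name, list(renamed[0].input), list(renamed[0].output))
    # g1_abs0 ['g1_x'] ['g1_y']
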
+    @staticmethod
+    def get_model_name_abbr(node):
+        no = node.name.split('_')[-1]
+        return 'm_' + no
+
+    @staticmethod
+    def get_model_id_from_arg0(nodes, node):
+        arg0_name = node.input[0]
+        c_node = [n_ for n_ in nodes if
+                  n_.op_type == 'Constant' and n_.output[0] == arg0_name]
+        assert len(c_node) == 1, 'internal error, multiple nodes with the same output.'
+        c_node = c_node[0]
+        tensor_value = onnx.helper.get_attribute_value(c_node.attribute[0])
+        _id = numpy_helper.to_array(tensor_value).item()
+        return _id
+
+    @classmethod
+    def _unfold_model_node(cls, container, name, model, io_mapping=None):
+        top_container = container
+        while top_container.parent is not None:  # only one opset_import in the model.
+            top_container = top_container.parent
+
+        renamed_nodes = cls._rename_graph(model.graph, name, container)
+        onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]
+
+        top_container.node_domain_version_pair_sets.update(
+            [(opset_.domain, opset_.version) for opset_ in model.opset_import])
+        return onnx_nodes
+
+    @classmethod
+    def unfold_model(cls, oxml, id_to_model, io_mapping=None):
+        container = _Container().add_model(oxml)
+        nodes = []
+        for _nid, _node in enumerate(oxml.graph.node):
+            if _node.op_type != '_ModelFunctionCall':
+                nodes.append(_node)
+            else:
+                model_id = cls.get_model_id_from_arg0(list(oxml.graph.node), _node)
+                if model_id not in id_to_model:
+                    raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))
+
+                prefix = cls.get_model_name_abbr(_node)
+                nest_model = id_to_model[model_id]
+
+                input_mapping = []
+                output_mapping = []
+                for idx_, in_ in enumerate(nest_model.graph.input):
+                    _renamed_in = "{}_{}".format(prefix, in_.name)
+                    _nd = onnx.helper.make_node('Identity',
+                                                [_node.input[idx_ + 1]],  # the first arg is model id, skip it.
+                                                [_renamed_in],
+                                                name='i_' + _renamed_in)
+                    input_mapping.append(_nd)
+                nds = cls._unfold_model_node(container,
+                                             prefix,
+                                             nest_model,
+                                             io_mapping)
+                for idx_, out_ in enumerate(nest_model.graph.output):
+                    if idx_ >= len(_node.output):
+                        continue
+                    _renamed_out = "{}_{}".format(prefix, out_.name)
+                    _nd = onnx.helper.make_node('Identity',
+                                                [_renamed_out],
+                                                [_node.output[idx_]],
+                                                name='o_' + _renamed_out)
+                    output_mapping.append(_nd)
+                if io_mapping is not None:
+                    assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
+                    input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
+                # attention: the order of the list operations is important, which avoids the topological sort.
+                nodes.extend(input_mapping)
+                nodes.extend(nds)
+                nodes.extend(output_mapping)
+
+        intlzs = cls._remove_unused_initializers(nodes, container.initializer)
+        oxml = copy.deepcopy(oxml)
+        del oxml.graph.node[:]
+        oxml.graph.node.extend(nodes)
+        del oxml.graph.initializer[:]
+        oxml.graph.initializer.extend(intlzs)
+        return oxml
+
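The pattern `unfold_model` consumes: the traced graph carries the registered model's Python `id()` as an int64 `Constant` feeding the `_ModelFunctionCall` node, and the prefix for the inlined copy is derived from the call node's name. A sketch of that node pair (the id value here is hypothetical):

    import numpy as np
    from onnx import helper, numpy_helper

    model_id = 140234567890000          # hypothetical id() of a registered ModelProto
    const = helper.make_node('Constant', [], ['mid'], name='c_mid',
                             value=numpy_helper.from_array(np.array(model_id, dtype=np.int64)))
    call = helper.make_node('_ModelFunctionCall', ['mid', 'x'], ['y'],
                            name='fn_1', domain='ai.onnx.contrib')
    # unfold_model() reads the id via get_model_id_from_arg0, prefixes the nested
    # graph with 'm_1' (from get_model_name_abbr), and bridges x/y with Identity nodes.
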
+    @classmethod
+    def topological_sort(cls, container, nodes, inputs, outputs):
+        op_output_map = {}
+        DynNode = namedtuple('DynNode', ['name', 'output'])
+        input_nodes = [DynNode(name='placeholder',
+                               output=[nm_.name for nm_ in inputs] +
+                                      [it_.name for it_ in container.initializer])] + \
+                      [nd_ for nd_ in nodes if nd_.op_type == 'Constant']
+
+        for nd_ in nodes + input_nodes:
+            for ky_ in nd_.output:
+                op_output_map[ky_] = nd_
+
+        edges = {}
+        for op in nodes:
+            for x in op.input:
+                if x == '':
+                    continue
+                try:
+                    predecessor = op_output_map[x]
+                except KeyError:
+                    raise RuntimeError(
+                        "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
+
+                val = edges.get(predecessor.name, [])
+                val.append(op)
+                edges[predecessor.name] = val
+
+        for y_ in outputs:
+            op = op_output_map[y_.name].name
+            if op not in edges:
+                edges[op] = []
+
+        visited = set()
+        sorted_nodes = []
+        unfinished_nodes = set()
+
+        def recursive_helper(node):
+            if node.name in visited:
+                return
+
+            if node.name in unfinished_nodes:
+                raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))
+
+            unfinished_nodes.add(node.name)
+            if node.name in edges:  # if the node's output is not in the Graph output.
+                assert node.name != '', 'this topological-sort depends on the unique node name.'
+                for successor in edges[node.name]:
+                    recursive_helper(successor)
+
+            unfinished_nodes.remove(node.name)
+            visited.add(node.name)
+            if node is not input_nodes[0]:
+                sorted_nodes.insert(0, node)
+
+        for nd_ in input_nodes:
+            recursive_helper(nd_)
+
+        return sorted_nodes
+
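The sort is a depth-first traversal seeded by a placeholder pseudo-node that owns all graph inputs and initializers (plus every `Constant` node); it raises on cycles and on tensors with no producer. A minimal sketch; note the container argument only needs an `initializer` attribute:

    from types import SimpleNamespace
    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp._utils import ONNXModelUtils

    n_abs = helper.make_node('Abs', ['x'], ['t'], name='n_abs')
    n_neg = helper.make_node('Neg', ['t'], ['y'], name='n_neg')
    inputs = [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])]
    outputs = [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])]

    ordered = ONNXModelUtils.topological_sort(
        SimpleNamespace(initializer=[]), [n_neg, n_abs], inputs, outputs)
    print([n.name for n in ordered])    # ['n_abs', 'n_neg']
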
+    @staticmethod
+    def _remove_unused_initializers(nodes, initializers, reserved_names=None):
+        if reserved_names is None:
+            reserved_names = set()
+        nodes_input_set = set()
+        for nd_ in nodes:
+            nodes_input_set.update(n_ for n_ in nd_.input)
+
+        return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reserved_names]
+
|
+
@classmethod
|
|
231
|
+
def join_models(cls, *models, io_mapping=None):
|
|
232
|
+
# generate the prefix id for the embedding graph to avoid the name conflict
|
|
233
|
+
mdl_prefix = []
|
|
234
|
+
for _i in range(len(models)):
|
|
235
|
+
mdl_prefix.append("g{}".format(_i + 1))
|
|
236
|
+
|
|
237
|
+
inputs = cls._rename_iter(models[0].graph.input, mdl_prefix[0])
|
|
238
|
+
outputs = cls._rename_iter(models[-1].graph.output, mdl_prefix[-1])
|
|
239
|
+
|
|
240
|
+
port_mapping = {}
|
|
241
|
+
if io_mapping is not None:
|
|
242
|
+
assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
|
|
243
|
+
ModelPort = namedtuple('ModelPort', "input output")
|
|
244
|
+
ports = []
|
|
245
|
+
for _idx in range(len(models)):
|
|
246
|
+
mio = ModelPort([cls.merge_name(mdl_prefix[_idx], _x.name) for _x in models[_idx].graph.input],
|
|
247
|
+
[cls.merge_name(mdl_prefix[_idx], _y.name) for _y in models[_idx].graph.output])
|
|
248
|
+
ports.append(mio)
|
|
249
|
+
port_mapping = io_mapping(ports)
|
|
250
|
+
for _idx in range(len(models) - 1):
|
|
251
|
+
for _i, _x in enumerate(models[_idx + 1].graph.input):
|
|
252
|
+
iname = cls.merge_name(mdl_prefix[_idx + 1], _x.name)
|
|
253
|
+
if iname not in port_mapping:
|
|
254
|
+
oname = cls.merge_name(mdl_prefix[_idx], models[_idx].graph.output[_i].name)
|
|
255
|
+
port_mapping[iname] = oname
|
|
256
|
+
|
|
257
|
+
nodes = []
|
|
258
|
+
container = _Container()
|
|
259
|
+
for _idx, _m in enumerate(models):
|
|
260
|
+
container.add_model(_m)
|
|
261
|
+
nodes += cls._rename_graph(_m.graph, mdl_prefix[_idx], container)
|
|
262
|
+
|
|
263
|
+
for _n in nodes:
|
|
264
|
+
replaceable = False
|
|
265
|
+
for _i in _n.input:
|
|
266
|
+
if _i in port_mapping:
|
|
267
|
+
replaceable = True
|
|
268
|
+
break
|
|
269
|
+
if replaceable:
|
|
270
|
+
new_input = copy.deepcopy(_n.input)
|
|
271
|
+
del _n.input[:]
|
|
272
|
+
_n.input.extend([port_mapping[_i] if _i in port_mapping else _i for _i in new_input])
|
|
273
|
+
|
|
274
|
+
name = "_".join([_mdl.graph.name for _mdl in models])
|
|
275
|
+
domains = set()
|
|
276
|
+
_opset = []
|
|
277
|
+
for _mdl in models:
|
|
278
|
+
for _ops in _mdl.opset_import:
|
|
279
|
+
domain = _ops.domain if _ops.domain else "ai.onnx"
|
|
280
|
+
if domain in domains:
|
|
281
|
+
if domain == "ai.onnx":
|
|
282
|
+
assert _ops.version == _opset[0].version, \
|
|
283
|
+
f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
|
|
284
|
+
else:
|
|
285
|
+
domains.add(domain)
|
|
286
|
+
if domain == "ai.onnx":
|
|
287
|
+
_opset.insert(0, _ops)
|
|
288
|
+
else:
|
|
289
|
+
_opset.append(_ops)
|
|
290
|
+
|
|
291
|
+
inits = cls._remove_unused_initializers(nodes, container.initializer)
|
|
292
|
+
g = helper.make_graph(nodes, name, inputs, outputs,
|
|
293
|
+
initializer=inits,
|
|
294
|
+
value_info=container.value_info)
|
|
295
|
+
|
|
296
|
+
if hasattr(helper, 'make_model_gen_version'):
|
|
297
|
+
# make_model_gen_version doesn't accept the custom domain.
|
|
298
|
+
m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
|
|
299
|
+
m.opset_import.extend(_opset[1:])
|
|
300
|
+
else:
|
|
301
|
+
m = helper.make_model(g, opset_imports=_opset)
|
|
302
|
+
return m
|
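`join_models` chains models positionally: every graph is namespaced with a `gN_` prefix, and unless `io_mapping` says otherwise, input i of model N+1 is wired to output i of model N; opset imports are merged with the `ai.onnx` entry kept first. A minimal two-model sketch:

    from onnx import helper, TensorProto
    from onnxruntime_extensions.pnp._utils import ONNXModelUtils

    def _toy(op, graph_name):
        node = helper.make_node(op, ['x'], ['y'], name=op.lower() + '0')
        graph = helper.make_graph([node], graph_name,
                                  [helper.make_tensor_value_info('x', TensorProto.FLOAT, [2])],
                                  [helper.make_tensor_value_info('y', TensorProto.FLOAT, [2])])
        return helper.make_model(graph, opset_imports=[helper.make_opsetid('', 13)])

    merged = ONNXModelUtils.join_models(_toy('Abs', 'ga'), _toy('Neg', 'gb'))
    print([n.name for n in merged.graph.node])     # ['g1_abs0', 'g2_neg0']
    print([i.name for i in merged.graph.input],    # ['g1_x']
          [o.name for o in merged.graph.output])   # ['g2_y']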