onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43)
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
onnxruntime_extensions/pnp/_torchext.py
@@ -0,0 +1,310 @@
+ import onnx
+ import torch
+ import numpy as np
+ from typing import Any
+ from onnx import helper
+ from onnx import onnx_pb as onnx_proto
+ from distutils.version import LooseVersion
+ from torch.onnx import register_custom_op_symbolic
+
+ from ._utils import ONNXModelUtils
+ from ._base import CustomFunction, ProcessingTracedModule, is_processing_module
+ from ._onnx_ops import ox as _ox, schema as _schema
+ from ._onnx_ops import ONNXElementContainer, make_model_ex
+ from .._ortapi2 import OrtPyFunction, get_opset_version_from_ort
+
+
+ def _is_numpy_object(x):
+     return isinstance(x, (np.ndarray, np.generic))
+
+
+ def _is_numpy_string_type(arr):
+     return arr.dtype.kind in {'U', 'S'}
+
+
+ def _is_string_type(x):
+     if isinstance(x, list):
+         return any(_is_string_type(e) for e in x)
+     elif isinstance(x, torch.Tensor):
+         return False
+     elif not _is_numpy_object(x):
+         x = np.array(x)
+     return _is_numpy_string_type(x)
+
+
+ def _to_onnx_type(dtype):
+     ty_dict = {torch.bool: onnx_proto.TensorProto.BOOL,
+                torch.float32: onnx_proto.TensorProto.FLOAT,
+                torch.float64: onnx_proto.TensorProto.DOUBLE,
+                torch.long: onnx_proto.TensorProto.INT64,
+                torch.int32: onnx_proto.TensorProto.INT32}
+     # ...
+     return ty_dict.get(dtype, onnx_proto.TensorProto.STRING)
+
+
+ class OnnxOpFunction(CustomFunction):
+     @classmethod
+     def get_next_id_name(cls, name_base):
+         name = 'cls' if name_base is None else name_base
+         _cid = getattr(cls, '_cid', 1)
+         cls._cid = _cid + 1
+         return "{}_{}".format(name, _cid)
+
+     @staticmethod
+     def jvp(ctx: Any, *grad_inputs: Any) -> Any:
+         pass
+
+     @staticmethod
+     def backward(ctx: Any, *grad_outputs: Any) -> Any:
+         return grad_outputs
+
+     @classmethod
+     def build_model(cls, opset_version, *args):
+         # build the one node graph
+         if isinstance(args[0], list):
+             args = [np.asarray(_i) for _i in args]
+         ec = ONNXElementContainer(get_opset_version_from_ort() if opset_version is None else opset_version)
+         attrs = cls.attrs
+         vi_inputs = [helper.make_tensor_value_info(
+             'it_' + str(id(_arg)), _to_onnx_type(_arg.dtype), list(_arg.shape))
+             for _arg in args]
+         inputs = [_vi.name for _vi in vi_inputs]
+         if hasattr(cls.opb_func, 'outputs') and len(cls.opb_func.outputs) > 0:
+             vi_outputs = [helper.make_tensor_value_info(
+                 cls.get_next_id_name('ot'), *_schm) for _schm in cls.opb_func.outputs]
+         else:
+             vi_outputs = [helper.make_tensor_value_info(
+                 cls.get_next_id_name('ot'), onnx_proto.TensorProto.FLOAT, []
+             )]
+         outputs = [_vi.name for _vi in vi_outputs]
+         # build the node
+         opfunc = cls.opb_func
+         opfunc(inputs, outputs, ec, None, **attrs)
+         g = helper.make_graph(ec.nodes, cls.get_next_id_name('g'), vi_inputs, vi_outputs)
+         m = make_model_ex(g, ec.node_domain_version_pair_sets, ec.target_opset)
+         return m
+
+     @classmethod
+     @torch.jit.unused
+     def _onnx_call(cls, ctx, *args) -> Any:
+         m = cls.build_model(None, *args)
+         try:
+             f = OrtPyFunction.from_model(m)
+             result = f(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args))
+         except Exception as e:
+             onnx.save_model(m, '_temp_debugging.onnx')
+             raise e
+
+         results = result if isinstance(result, tuple) else [result]
+         return tuple([torch.from_numpy(_o) for _o in results]) if len(results) > 1 else torch.from_numpy(results[0])
+
+     @classmethod
+     def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
+         return cls._onnx_call(ctx, *args, **kwargs)
+
+     @classmethod
+     def symbolic(cls, g, *args):
+         return g.op(cls.op_type, *args)
+
+
+ def create_op_function(op_type: str, func, **attrs):
+     if _ox.is_raw(func):
+         func = _schema(func.__func__)
+     cls = type(_ox.get_unique_operator_type_name(op_type), (OnnxOpFunction,),
+                dict(
+                    op_type=op_type,
+                    opb_func=func,
+                    attrs=attrs
+                ))
+     return cls.apply  # noqa
+
+
+ onnx_pad = create_op_function('Pad', _ox.pad)
+ onnx_where = create_op_function('Where', _ox.where)
+ onnx_greater = create_op_function('Greater', _ox.greater)
+
+
+ class _OnnxModelFunction:
+     id_object_map = {}  # cannot use the string directly since jit.script doesn't support the data type
+     id_function_map = {}
+     str_model_function_id = '_model_function_id'
+     str_model_id = '_model_id'
+     str_model_attached = '_model_attached'
+
+
+ @torch.jit.ignore
+ def _invoke_onnx_model(model_id: int, *args, **kwargs):
+     func = _OnnxModelFunction.id_function_map.get(model_id, None)
+     if not func:
+         model_or_path = _OnnxModelFunction.id_object_map.get(model_id)
+         if model_or_path is None:
+             raise ValueError("cannot find id={} registered!".format(model_id))
+         func = OrtPyFunction.from_model(model_or_path)
+         _OnnxModelFunction.id_function_map[model_id] = func
+     results = func(*list(_i.numpy() if isinstance(_i, torch.Tensor) else _i for _i in args), **kwargs)
+     return tuple(
+         [torch.from_numpy(_o) for _o in results]) if isinstance(results, tuple) else torch.from_numpy(results)
+
+
+ @torch.jit.ignore
+ def invoke_onnx_model1(model_id: int, arg0):
+     return _invoke_onnx_model(model_id, arg0)
+
+
+ @torch.jit.ignore
+ def invoke_onnx_model2(model_id: int, arg0, arg1):
+     return _invoke_onnx_model(model_id, arg0, arg1)
+
+
+ @torch.jit.ignore
+ def invoke_onnx_model3(model_id: int, arg0, arg1, arg2):
+     return _invoke_onnx_model(model_id, arg0, arg1, arg2)
+
+
+ class _OnnxTracedFunction(CustomFunction):
+     @classmethod
+     def forward(cls, ctx: Any, *args: Any, **kwargs: Any) -> Any:
+         return _invoke_onnx_model(args[0].item(), *args[1:], **kwargs)
+
+     @classmethod
+     def symbolic(cls, g, *args):
+         ret = g.op('ai.onnx.contrib::_ModelFunctionCall', *args)
+         model_id = torch.onnx.symbolic_helper._maybe_get_scalar(args[0])  # noqa
+         if not model_id:
+             return ret
+
+         func = _OnnxModelFunction.id_function_map.get(model_id.item(), None)
+         if not func or len(func.outputs) <= 1:
+             return ret
+
+         outputs = [ret]
+         for _ in range(len(func.outputs) - 1):
+             outputs.append(ret.node().addOutput())
+
+         return tuple(outputs)
+
+
+ def create_model_function(model_or_path):
+     _id = id(model_or_path)
+     assert _id != 0, "internal error: the id of a Python object is 0."
+     _OnnxModelFunction.id_object_map[_id] = model_or_path
+     return _id
+
+
+ def get_id_models():
+     return _OnnxModelFunction.id_object_map
+
+
+ class OnnxTracedModelFunction:
+     def __init__(self, onnx_model):
+         self.func_id = create_model_function(onnx_model)
+
+     def __call__(self, *args, **kwargs):
+         return _OnnxTracedFunction.apply(torch.tensor(self.func_id), *args, **kwargs)
+
+
+ class _OnnxModelModule(torch.nn.Module):
+     def __init__(self, mdl):
+         super(_OnnxModelModule, self).__init__()
+         self.function = OnnxTracedModelFunction(mdl)
+
+     def forward(self, *args):
+         return self.function(*args)
+
+
+ def _symbolic_pythonop(g: torch._C.Graph, n: torch._C.Node, *args, **kwargs):
+     name = kwargs["name"]
+     if name.startswith(invoke_onnx_model1.__name__[:-1]):
+         # NB: if you want to get the value of the first argument, i.e. the model id,
+         # you can get it by torch.onnx.symbolic_helper._maybe_get_scalar(args[0]).item()
+         ret = g.op("ai.onnx.contrib::_ModelFunctionCall", *args)
+     else:
+         # Log a warning and return None (warnings.warn returns None).
+         import warnings
+         return warnings.warn("prim::PythonOp: unknown node kind: " + name)
+     # Copy type and shape from original node.
+     ret.setType(n.output().type())
+     return ret
+
+
+ if LooseVersion(torch.__version__) >= LooseVersion("1.11"):
+     register_custom_op_symbolic("prim::PythonOp", _symbolic_pythonop, 1)
+
+
+ class SequentialProcessingModule(ProcessingTracedModule):
+     def __init__(self, *models):
+         super(SequentialProcessingModule, self).__init__()
+         self.model_list = torch.nn.ModuleList()
+         for mdl_ in models:
+             if isinstance(mdl_, onnx.ModelProto):
+                 self.model_list.append(_OnnxModelModule(mdl_))
+             elif is_processing_module(mdl_):
+                 self.model_list.append(mdl_)
+             else:
+                 assert callable(mdl_), "the model type is not recognizable."
+                 self.model_list.append(ProcessingTracedModule(mdl_))
+
+     def forward(self, *args):
+         outputs = args
+         with torch.no_grad():
+             for mdl_ in self.model_list:
+                 if not isinstance(outputs, tuple):
+                     outputs = (outputs,)
+                 outputs = mdl_(*outputs)
+
+         return outputs
+
+     def export(self, *args, **kwargs):
+         prefix_m = None
+         core_m = self
+         raw_input_flag = any(_is_string_type(x_) for x_ in args)
+         if raw_input_flag:
+             # NB: torch.onnx.export doesn't support exporting a module that accepts string inputs,
+             # so in this case the module is split into two parts and the first part uses the customized export.
+             m0 = self.model_list[0]
+             new_args = m0(*args)
+             if not isinstance(new_args, tuple):
+                 new_args = (new_args, )
+             prefix_m = m0.export(*args, **kwargs)
+             args = new_args
+             core_m = SequentialProcessingModule(*self.model_list[1:])
+         if prefix_m is None:
+             return super().export(*args, **kwargs)
+         else:
+             oxml = core_m.export(*args, **kwargs)
+             model = ONNXModelUtils.join_models(prefix_m, oxml)
+
+             # Rename the input/output node names if the user has provided any substitutions.
+             # Ref: https://github.com/onnx/onnx/issues/2052
+             # Known issue: this logic doesn't deal with subgraphs.
+             # (kwargs.get avoids a KeyError when only one of the two names is supplied.)
+             if kwargs.get('input_names') or kwargs.get('output_names'):
+                 swaps = {}
+                 if kwargs.get('input_names'):
+                     assert len(model.graph.input) == len(kwargs['input_names']), \
+                         "Expecting {} input names but got {}".format(
+                             len(model.graph.input), len(kwargs['input_names']))
+                     for n, new_name in zip(model.graph.input, kwargs['input_names']):
+                         swaps[n.name] = new_name
+                         n.name = new_name
+
+                 if kwargs.get('output_names'):
+                     assert len(model.graph.output) == len(kwargs['output_names']), \
+                         "Expecting {} output names but got {}".format(
+                             len(model.graph.output), len(kwargs['output_names']))
+                     for n, new_name in zip(model.graph.output, kwargs['output_names']):
+                         swaps[n.name] = new_name
+                         n.name = new_name
+
+                 if swaps:
+                     for n in model.graph.node:
+                         for j in range(len(n.input)):
+                             n.input[j] = swaps.get(n.input[j], n.input[j])
+
+                         for j in range(len(n.output)):
+                             n.output[j] = swaps.get(n.output[j], n.output[j])
+
+                     for n in model.graph.initializer:
+                         n.name = swaps.get(n.name, n.name)
+
+             return model
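
For orientation, a minimal usage sketch of the module above. It assumes SequentialProcessingModule is re-exported from onnxruntime_extensions.pnp and that "decoder.onnx" is a placeholder model path; per the constructor above, a plain callable such as torch.nn.Softmax is wrapped into a ProcessingTracedModule, and the ONNX part runs eagerly through OrtPyFunction:

    import onnx
    import torch
    from onnxruntime_extensions.pnp import SequentialProcessingModule  # assumed re-export

    core = onnx.load_model("decoder.onnx")                # placeholder path
    post = SequentialProcessingModule(core, torch.nn.Softmax(dim=-1))
    scores = post(torch.rand(1, 16))                      # eager run; the ONNX part executes via ORT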
onnxruntime_extensions/pnp/_unifier.py
@@ -0,0 +1,45 @@
+ import onnx
+
+ from .._ortapi2 import get_opset_version_from_ort
+ from ._utils import ONNXModelUtils
+ from ._base import is_processing_module
+ from ._torchext import get_id_models, SequentialProcessingModule
+
+
+ def export(m, *args,
+            opset_version=0,
+            output_path=None,
+            export_params=True,
+            verbose=False,
+            input_names=None,
+            output_names=None,
+            operator_export_type=None,
+            do_constant_folding=True,
+            dynamic_axes=None,
+            keep_initializers_as_inputs=None,
+            custom_opsets=None,
+            io_mapping=None):
+     """
+     Export all models and modules into a merged ONNX model.
+     """
+     if opset_version == 0:
+         opset_version = get_opset_version_from_ort()
+
+     if not is_processing_module(m):
+         m = SequentialProcessingModule(m)
+
+     model = m.export(*args, opset_version=opset_version,
+                      output_path=output_path,
+                      export_params=export_params,
+                      verbose=verbose,
+                      input_names=input_names,
+                      output_names=output_names,
+                      operator_export_type=operator_export_type,
+                      do_constant_folding=do_constant_folding,
+                      dynamic_axes=dynamic_axes,
+                      keep_initializers_as_inputs=keep_initializers_as_inputs,
+                      custom_opsets=custom_opsets)
+     full_m = ONNXModelUtils.unfold_model(model, get_id_models(), io_mapping)
+     if output_path is not None:
+         onnx.save_model(full_m, output_path)
+     return full_m
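
A hedged sketch of how this entry point might be called, assuming the function is re-exported as pnp.export and reusing the post module from the previous sketch. With opset_version=0 the opset is queried from the installed onnxruntime, and passing output_path also saves the unfolded model:

    import torch
    from onnxruntime_extensions.pnp import export  # assumed re-export of this function

    merged = export(post, torch.rand(1, 16),      # example input used for tracing
                    opset_version=0,              # 0: take the opset from onnxruntime
                    output_path="merged.onnx")    # placeholder output path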
onnxruntime_extensions/pnp/_utils.py
@@ -0,0 +1,302 @@
+ import copy
+ import onnx
+ from onnx import helper, numpy_helper
+ from collections import namedtuple
+
+
+ class _Container:
+     def __init__(self):
+         self.parent = None
+         self.initializer = []
+         self.value_info = []
+         self.nodes = []
+         self.node_domain_version_pair_sets = {}
+
+     def add_model(self, oxml):
+         self.initializer.extend(oxml.graph.initializer)
+         self.value_info.extend(oxml.graph.value_info)
+         self.nodes.extend(oxml.graph.node)
+         self.node_domain_version_pair_sets.update(
+             [(opset_.domain, opset_.version) for opset_ in oxml.opset_import])
+         return self
+
+
+ class ONNXModelUtils:
+     @staticmethod
+     def merge_name(prefix, name):
+         return "{}_{}".format(prefix, name)
+
+     @staticmethod
+     def _rename_iter(iterables, prefix_name, inplace=False):
+         new_iz = iterables if inplace else [copy.deepcopy(iz_) for iz_ in iterables]
+         for iz_ in new_iz:
+             iz_.name = ONNXModelUtils.merge_name(prefix_name, iz_.name)
+         return new_iz
+
+     @classmethod
+     def _rename_graph(cls, graph, prefix, graph_or_container):
+         def io_rename(node, prefix_name, idx):
+             new_node = copy.deepcopy(node)
+             if not node.name:
+                 new_node.name = cls.merge_name(prefix_name, "op{}".format(idx))
+             else:
+                 new_node.name = cls.merge_name(prefix_name, node.name)
+
+             del new_node.input[:]
+             new_node.input.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.input)
+             del new_node.output[:]
+             new_node.output.extend(ONNXModelUtils.merge_name(prefix_name, nm_) if nm_ else '' for nm_ in node.output)
+             return new_node
+
+         assert prefix is not None, 'The graph prefix could not be None'
+         graph_or_container.initializer.extend(cls._rename_iter(graph.initializer, prefix))
+         graph_or_container.value_info.extend(cls._rename_iter(graph.value_info, prefix))
+         return list(io_rename(nd_, prefix, idx_) for idx_, nd_ in enumerate(graph.node))
+
+     @classmethod
+     def _process_node_body(cls, node, prefix):
+         if all(attr.name != 'body' for attr in node.attribute):
+             return node
+
+         def _process_attr(attr, prefix_name):
+             if attr.name == 'body':
+                 new_attr = copy.deepcopy(attr)
+                 del new_attr.g.value_info[:]
+                 del new_attr.g.node[:]
+                 new_attr.g.node.extend(cls._rename_graph(attr.g, prefix_name, new_attr.g))
+                 cls._rename_iter(new_attr.g.input, prefix_name, inplace=True)
+                 cls._rename_iter(new_attr.g.output, prefix_name, inplace=True)
+                 return new_attr
+             else:
+                 return attr
+
+         attr_list = list(_process_attr(attr_, prefix) for attr_ in node.attribute)
+         del node.attribute[:]
+         node.attribute.extend(attr_list)
+         return node
+
+     @staticmethod
+     def get_model_name_abbr(node):
+         no = node.name.split('_')[-1]
+         return 'm_' + no
+
+     @staticmethod
+     def get_model_id_from_arg0(nodes, node):
+         arg0_name = node.input[0]
+         c_node = [n_ for n_ in nodes if
+                   n_.op_type == 'Constant' and n_.output[0] == arg0_name]
+         assert len(c_node) == 1, 'internal error, multiple nodes with the same output.'
+         c_node = c_node[0]
+         tensor_value = onnx.helper.get_attribute_value(c_node.attribute[0])
+         _id = numpy_helper.to_array(tensor_value).item()
+         return _id
+
+     @classmethod
+     def _unfold_model_node(cls, container, name, model, io_mapping=None):
+         top_container = container
+         while top_container.parent is not None:  # only one opset_import in the model.
+             top_container = top_container.parent
+
+         renamed_nodes = cls._rename_graph(model.graph, name, container)
+         onnx_nodes = [cls._process_node_body(nd_, name) for nd_ in renamed_nodes]
+
+         top_container.node_domain_version_pair_sets.update(
+             [(opset_.domain, opset_.version) for opset_ in model.opset_import])
+         return onnx_nodes
+
+     @classmethod
+     def unfold_model(cls, oxml, id_to_model, io_mapping=None):
+         container = _Container().add_model(oxml)
+         nodes = []
+         for _nid, _node in enumerate(oxml.graph.node):
+             if _node.op_type != '_ModelFunctionCall':
+                 nodes.append(_node)
+             else:
+                 model_id = cls.get_model_id_from_arg0(list(oxml.graph.node), _node)
+                 if model_id not in id_to_model:
+                     raise RuntimeError("Cannot find the model id({}) in the table".format(model_id))
+
+                 prefix = cls.get_model_name_abbr(_node)
+                 nest_model = id_to_model[model_id]
+
+                 input_mapping = []
+                 output_mapping = []
+                 for idx_, in_ in enumerate(nest_model.graph.input):
+                     _renamed_in = "{}_{}".format(prefix, in_.name)
+                     _nd = onnx.helper.make_node('Identity',
+                                                 [_node.input[idx_ + 1]],  # the first arg is the model id, skip it.
+                                                 [_renamed_in],
+                                                 name='i_' + _renamed_in)
+                     input_mapping.append(_nd)
+                 nds = cls._unfold_model_node(container,
+                                              prefix,
+                                              nest_model,
+                                              io_mapping)
+                 for idx_, out_ in enumerate(nest_model.graph.output):
+                     if idx_ >= len(_node.output):
+                         continue
+                     _renamed_out = "{}_{}".format(prefix, out_.name)
+                     _nd = onnx.helper.make_node('Identity',
+                                                 [_renamed_out],
+                                                 [_node.output[idx_]],
+                                                 name='o_' + _renamed_out)
+                     output_mapping.append(_nd)
+                 if io_mapping is not None:
+                     assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
+                     input_mapping, output_mapping = io_mapping(input_mapping, output_mapping)
+                 # Note: the order of these extensions matters; it preserves a topological order,
+                 # so an explicit topological sort is not needed.
+                 nodes.extend(input_mapping)
+                 nodes.extend(nds)
+                 nodes.extend(output_mapping)
+
+         intlzs = cls._remove_unused_initializers(nodes, container.initializer)
+         oxml = copy.deepcopy(oxml)
+         del oxml.graph.node[:]
+         oxml.graph.node.extend(nodes)
+         del oxml.graph.initializer[:]
+         oxml.graph.initializer.extend(intlzs)
+         return oxml
+
+     @classmethod
+     def topological_sort(cls, container, nodes, inputs, outputs):
+         op_output_map = {}
+         DynNode = namedtuple('DynNode', ['name', 'output'])
+         input_nodes = [DynNode(name='placeholder',
+                                output=[nm_.name for nm_ in inputs] +
+                                       [it_.name for it_ in container.initializer])] + \
+                       [nd_ for nd_ in nodes if nd_.op_type == 'Constant']
+
+         for nd_ in nodes + input_nodes:
+             for ky_ in nd_.output:
+                 op_output_map[ky_] = nd_
+
+         edges = {}
+         for op in nodes:
+             for x in op.input:
+                 if x == '':
+                     continue
+                 try:
+                     predecessor = op_output_map[x]
+                 except KeyError:
+                     raise RuntimeError(
+                         "{}: cannot find an operator to produce the tensor: {}".format(op.name, x)) from None
+
+                 val = edges.get(predecessor.name, [])
+                 val.append(op)
+                 edges[predecessor.name] = val
+
+         for y_ in outputs:
+             op = op_output_map[y_.name].name
+             if op not in edges:
+                 edges[op] = []
+
+         visited = set()
+         sorted_nodes = []
+         unfinished_nodes = set()
+
+         def recursive_helper(node):
+             if node.name in visited:
+                 return
+
+             if node.name in unfinished_nodes:
+                 raise RuntimeError("ONNX Graph is not a DAG, the cycle is found at {}".format(node.name))
+
+             unfinished_nodes.add(node.name)
+             if node.name in edges:  # if the node's output is not in the Graph output.
+                 assert node.name != '', 'this topological-sort depends on the unique node name.'
+                 for successor in edges[node.name]:
+                     recursive_helper(successor)
+
+             unfinished_nodes.remove(node.name)
+             visited.add(node.name)
+             if node is not input_nodes[0]:
+                 sorted_nodes.insert(0, node)
+
+         for nd_ in input_nodes:
+             recursive_helper(nd_)
+
+         return sorted_nodes
+
+     @staticmethod
+     def _remove_unused_initializers(nodes, initializers, reserved_names=None):
+         if reserved_names is None:
+             reserved_names = set()
+         nodes_input_set = set()
+         for nd_ in nodes:
+             nodes_input_set.update(n_ for n_ in nd_.input)
+
+         return [intlz_ for intlz_ in initializers if intlz_.name in nodes_input_set or intlz_.name in reserved_names]
+
+     @classmethod
+     def join_models(cls, *models, io_mapping=None):
+         # generate a prefix id for each embedded graph to avoid name conflicts
+         mdl_prefix = []
+         for _i in range(len(models)):
+             mdl_prefix.append("g{}".format(_i + 1))
+
+         inputs = cls._rename_iter(models[0].graph.input, mdl_prefix[0])
+         outputs = cls._rename_iter(models[-1].graph.output, mdl_prefix[-1])
+
+         port_mapping = {}
+         if io_mapping is not None:
+             assert callable(io_mapping), "io_mapping is a custom function to build the linkage of the models"
+             ModelPort = namedtuple('ModelPort', "input output")
+             ports = []
+             for _idx in range(len(models)):
+                 mio = ModelPort([cls.merge_name(mdl_prefix[_idx], _x.name) for _x in models[_idx].graph.input],
+                                 [cls.merge_name(mdl_prefix[_idx], _y.name) for _y in models[_idx].graph.output])
+                 ports.append(mio)
+             port_mapping = io_mapping(ports)
+         for _idx in range(len(models) - 1):
+             for _i, _x in enumerate(models[_idx + 1].graph.input):
+                 iname = cls.merge_name(mdl_prefix[_idx + 1], _x.name)
+                 if iname not in port_mapping:
+                     oname = cls.merge_name(mdl_prefix[_idx], models[_idx].graph.output[_i].name)
+                     port_mapping[iname] = oname
+
+         nodes = []
+         container = _Container()
+         for _idx, _m in enumerate(models):
+             container.add_model(_m)
+             nodes += cls._rename_graph(_m.graph, mdl_prefix[_idx], container)
+
+         for _n in nodes:
+             replaceable = False
+             for _i in _n.input:
+                 if _i in port_mapping:
+                     replaceable = True
+                     break
+             if replaceable:
+                 new_input = copy.deepcopy(_n.input)
+                 del _n.input[:]
+                 _n.input.extend([port_mapping[_i] if _i in port_mapping else _i for _i in new_input])
+
+         name = "_".join([_mdl.graph.name for _mdl in models])
+         domains = set()
+         _opset = []
+         for _mdl in models:
+             for _ops in _mdl.opset_import:
+                 domain = _ops.domain if _ops.domain else "ai.onnx"
+                 if domain in domains:
+                     if domain == "ai.onnx":
+                         assert _ops.version == _opset[0].version, \
+                             f"ai.onnx domain version doesn't match {_ops.version} != {_opset[0].version}"
+                 else:
+                     domains.add(domain)
+                     if domain == "ai.onnx":
+                         _opset.insert(0, _ops)
+                     else:
+                         _opset.append(_ops)
+
+         inits = cls._remove_unused_initializers(nodes, container.initializer)
+         g = helper.make_graph(nodes, name, inputs, outputs,
+                               initializer=inits,
+                               value_info=container.value_info)
+
+         if hasattr(helper, 'make_model_gen_version'):
+             # make_model_gen_version doesn't accept a custom domain.
+             m = helper.make_model_gen_version(g, opset_imports=_opset[:1])
+             m.opset_import.extend(_opset[1:])
+         else:
+             m = helper.make_model(g, opset_imports=_opset)
+         return m
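
To illustrate join_models: absent an io_mapping, the i-th output of each model is wired to the i-th input of the next one through the generated "g{n}"-prefixed names. A minimal sketch with placeholder file names (the import path is the module's private location inside the wheel):

    import onnx
    from onnxruntime_extensions.pnp._utils import ONNXModelUtils

    m1 = onnx.load_model("preprocess.onnx")       # placeholder
    m2 = onnx.load_model("core.onnx")             # placeholder
    joined = ONNXModelUtils.join_models(m1, m2)   # m1 outputs feed m2 inputs positionally
    onnx.save_model(joined, "joined.onnx")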