onnxruntime_extensions 0.14.0__cp313-cp313-macosx_11_0_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime_extensions/__init__.py +82 -0
- onnxruntime_extensions/_cuops.py +564 -0
- onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
- onnxruntime_extensions/_extensions_pydll.pyi +45 -0
- onnxruntime_extensions/_hf_cvt.py +331 -0
- onnxruntime_extensions/_ocos.py +133 -0
- onnxruntime_extensions/_ortapi2.py +274 -0
- onnxruntime_extensions/_torch_cvt.py +231 -0
- onnxruntime_extensions/_version.py +2 -0
- onnxruntime_extensions/cmd.py +66 -0
- onnxruntime_extensions/cvt.py +306 -0
- onnxruntime_extensions/onnxprocess/__init__.py +12 -0
- onnxruntime_extensions/onnxprocess/_builder.py +53 -0
- onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
- onnxruntime_extensions/onnxprocess/_session.py +355 -0
- onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
- onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
- onnxruntime_extensions/pnp/__init__.py +13 -0
- onnxruntime_extensions/pnp/_base.py +124 -0
- onnxruntime_extensions/pnp/_imagenet.py +65 -0
- onnxruntime_extensions/pnp/_nlp.py +148 -0
- onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
- onnxruntime_extensions/pnp/_torchext.py +310 -0
- onnxruntime_extensions/pnp/_unifier.py +45 -0
- onnxruntime_extensions/pnp/_utils.py +302 -0
- onnxruntime_extensions/pp_api.py +83 -0
- onnxruntime_extensions/tools/__init__.py +0 -0
- onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
- onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
- onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
- onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
- onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
- onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
- onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
- onnxruntime_extensions/util.py +186 -0
- onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
- onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
- onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
- onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
- onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,395 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
import onnx
|
|
5
|
+
|
|
6
|
+
from onnx import version_converter
|
|
7
|
+
from typing import List, Tuple, Union
|
|
8
|
+
|
|
9
|
+
from .utils import (
|
|
10
|
+
IoMapEntry,
|
|
11
|
+
IOEntryValuePreserver,
|
|
12
|
+
create_custom_op_checker_context,
|
|
13
|
+
sanitize_output_names,
|
|
14
|
+
TENSOR_TYPE_TO_ONNX_TYPE,
|
|
15
|
+
)
|
|
16
|
+
from .step import Step
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PrePostProcessor:
    """
    Class to handle running all the pre/post processing steps and updating the model.
    """

    def __init__(self, inputs: Union[None, List[onnx.ValueInfoProto]] = None, onnx_opset: int = 16):
        """
        Create a PrePostProcessor instance.

        Args:
            inputs: The inputs the model will use if pre-processing is added.
            onnx_opset: The ONNX opset to use.
                        Minimum is 16. 18 or higher is strongly preferred if image resizing is involved due to its
                        anti-aliasing ability.

        Raises:
            ValueError: if onnx_opset is lower than 16.
        """

        if onnx_opset < 16:
            raise ValueError("ONNX opset must be 16 or later.")

        self._onnx_opset = onnx_opset
        self._custom_op_checker_context = create_custom_op_checker_context(onnx_opset)

        self.pre_processors = []  # type: List[Step]
        self.post_processors = []  # type: List[Step]

        # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors
        self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
        self._post_processor_connections = []  # type: List[List[IoMapEntry]]

        # explicitly join outputs from Steps in pre_processors to inputs of the original model
        # format is Step or step name, step_idx, name of graph input/output
        #
        # Pre-processing we connect Step output to original model:
        #   - step_idx is for Step.output_names, and name is in graph.input
        #
        # Post-processing we connect the original model output to the Step input
        #   - step_idx is for Step.input_names, and name is in graph.output
        self._pre_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]
        self._post_processing_joins = None  # type: Union[None,List[Tuple[Union[Step, str], int, str]]]

        self._inputs = inputs if inputs else []

        # Preserve outputs referenced by an explicit IoMapEntry so the steps in between do not consume them.
        # This allows a value to have more than one consumer: IOEntryValuePreserver keeps the value as a
        # graph output until the consumer step has been applied.
        self.outputs_preserver = []  # type: List[IOEntryValuePreserver]

    def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the pre-processing steps. The last step is automatically joined to the original model inputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for manual connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.
        """
        self.__add_processing(self.pre_processors, self._pre_processor_connections, items)

    def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
        """
        Add the post-processing steps. The first step is automatically joined to the original model outputs.

        Options are:
          Add Step with default connection of outputs from the previous step (if available) to inputs of this step.
          Add tuple of Step and list of IoMapEntry instances for connections to previous steps. This will be
          used to override any automatic connections.
            If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
            If IoMapEntry.producer is a step name it must match the name of a previous step.
        """
        self.__add_processing(self.post_processors, self._post_processor_connections, items)

    def _add_connection(self, consumer: Step, entry: IoMapEntry):
        """
        Validate a connection from entry's producer Step to `consumer`, then apply it.

        Args:
            consumer: Step whose input is being connected.
            entry: IoMapEntry describing the connection. The producer may be a Step or a registered Step name.

        Raises:
            ValueError: if either Step is not registered, if the producer was registered after the consumer, or if a
                        post-processing producer would feed a pre-processing consumer.
        """
        producer = self.__producer_from_step_or_str(entry.producer)

        # Black does annoying things with the multi-line 'if' conditions making the code far less readable
        # fmt: off
        if not ((producer in self.pre_processors or producer in self.post_processors) and
                (consumer in self.pre_processors or consumer in self.post_processors)):
            raise ValueError("Producer and Consumer processors must both be registered")

        if producer in self.pre_processors:
            if (consumer in self.pre_processors and
                    self.pre_processors.index(producer) > self.pre_processors.index(consumer)):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        elif producer in self.post_processors:
            if consumer not in self.post_processors:
                raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
            elif self.post_processors.index(producer) > self.post_processors.index(consumer):
                raise ValueError("Producer was registered after consumer and cannot be connected")
        # fmt: on

        assert isinstance(producer, Step)

        # Step.connect dereferences entry.producer directly, so hand it the resolved Step rather than a name.
        if entry.producer is not producer:
            entry = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)

        consumer.connect(entry)

    def run(self, model: onnx.ModelProto):
        """
        Update the model with the graph from each step in the pre and post processing pipelines.

        Args:
            model: model to add pre/post processing to.

        Returns:
            model with pre/post processing in it.

        Raises:
            ValueError: if the model has no default-domain ('' or 'ai.onnx') opset import, if the model opset is
                        newer than the opset this instance targets, or if a post-processing join references a value
                        that is not an output of the model graph.
        """

        # update the input model to the ONNX opset we're using. this is required as we implement the steps based on
        # the operator specs for this opset.
        model_opset = next(
            (entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"),
            None,
        )
        if model_opset is None:
            raise ValueError("Model does not import the default ONNX ('ai.onnx') opset.")

        if model_opset > self._onnx_opset:
            # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model
            # but there are no guarantees.
            # Would only break if ONNX operators used in the pre/post processing graphs have had spec changes.
            raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
        elif model_opset < self._onnx_opset:
            model = onnx.version_converter.convert_version(model, self._onnx_opset)

        def name_nodes(new_graph: onnx.GraphProto, prefix: str):
            # simple helper so all nodes are named. this makes it far easier to debug any issues.
            idx = 0
            for n in new_graph.node:
                if not n.name:
                    n.name = prefix + str(idx)
                    idx += 1

        def preserved_apply(processor: Step, *args):
            # Deactivate preservers whose consumer is this step (it consumes the value now), apply the step while
            # keeping every still-active preserved value as a graph output, then activate preservers for values this
            # step produces.
            for preserver in self.outputs_preserver:
                if preserver.consumer == processor:
                    preserver.is_active = False

            # values with multiple consumers are explicitly kept as graph outputs so later steps can still reach them
            graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
            graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)

            for preserver in self.outputs_preserver:
                if preserver.producer == processor:
                    preserver.is_active = True
                    # output_names now holds the prefixed names assigned by Step.apply
                    preserver.output = processor.output_names[preserver.producer_idx]
            return graph_for_step

        def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
            for connection in connections:
                assert connection.producer
                self._add_connection(processor, connection)

            return preserved_apply(processor, graph, self._custom_op_checker_context)

        # fix any invalid output names now if we're adding post-processing as the onnx parse_graph can't handle them
        if self.post_processors:
            sanitize_output_names(model.graph)

        graph = model.graph
        # add pre-processing
        if self.pre_processors:
            # create empty graph with pass through of the requested input name
            pre_process_graph = onnx.GraphProto()
            for i in self._inputs:
                pre_process_graph.input.append(i)
                pre_process_graph.output.append(i)

            # connect up the graph input names to the first pre-processing step based on order
            self.pre_processors[0]._connect_graph_inputs([vi.name for vi in self._inputs])

            for step, connections in zip(self.pre_processors, self._pre_processor_connections):
                pre_process_graph = connect_and_run(pre_process_graph, step, connections)

            # name all the nodes for easier debugging
            name_nodes(pre_process_graph, "pre_process_")

            if not self._pre_processing_joins:
                # default to 1:1 between outputs of last step with inputs of original model
                last_step = self.pre_processors[-1]
                num_entries = min(len(last_step.output_names), len(graph.input))
                self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]

            # map the pre-processing outputs to graph inputs
            # we may need a neater way to get possible outputs after merge_graphs
            step_graph_outputs = [o.name for o in pre_process_graph.output]
            io_map = []  # type: List[Tuple[str, str]]
            for step_or_name, step_idx, graph_input in self._pre_processing_joins:
                # joins may reference a Step or a Step name
                producer = self.__producer_from_step_or_str(step_or_name)
                io_map.append((producer.output_names[step_idx], graph_input))
                step_graph_outputs.remove(producer.output_names[step_idx])

            # add outputs from previous IoMapEntry producers to maintain them as graph outputs
            # until consumed by the final Step that requires them.
            step_graph_outputs += [o.name for o in graph.output if o.name not in step_graph_outputs]
            external_outputs = [
                i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs]
            if external_outputs:
                step_graph_outputs.extend(external_outputs)
            graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)

        # add post-processing
        if self.post_processors:
            orig_model_outputs = [o.name for o in model.graph.output]
            graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing

            # create default joins if needed
            if not self._post_processing_joins:
                # default to 1:1 between outputs of original model with inputs of first post-processing step
                first_step = self.post_processors[0]
                num_entries = min(len(first_step.input_names), len(orig_model_outputs))
                self._post_processing_joins = [(first_step, i, orig_model_outputs[i]) for i in range(0, num_entries)]

            # update the input names for the steps to match the values produced by the model
            for step_or_name, step_idx, graph_output in self._post_processing_joins:
                if graph_output not in graph_outputs:
                    raise ValueError(
                        f"Post-processing join references '{graph_output}' which is not an output of the graph.")
                consumer = self.__producer_from_step_or_str(step_or_name)
                consumer.input_names[step_idx] = graph_output

            # create empty graph with the values that will be available to the post-processing
            post_process_graph = onnx.GraphProto()
            for o in graph.output:
                post_process_graph.input.append(o)
                post_process_graph.output.append(o)

            for step, connections in zip(self.post_processors, self._post_processor_connections):
                post_process_graph = connect_and_run(post_process_graph, step, connections)

            name_nodes(post_process_graph, "post_process_")

            # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
            io_map = [(o, o) for o in graph_outputs]
            graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)

        # Make the output names nicer by removing prefixing from naming that occurred when applying the steps
        graph = PrePostProcessor.__cleanup_graph_output_names(graph)

        opset_imports = [onnx.helper.make_operatorsetid(domain, opset)
                         for domain, opset in self._custom_op_checker_context.opset_imports.items()]
        # find_min_ir_version_for doesn't support custom domains until ONNX 1.14 so extract the ONNX opset from the
        # imports and only pass that in.
        ir_version = onnx.helper.find_min_ir_version_for([entry for entry in opset_imports
                                                          if entry.domain == "" or entry.domain == "ai.onnx"])
        new_model = onnx.helper.make_model(graph, opset_imports=opset_imports, ir_version=ir_version)

        onnx.checker.check_model(new_model)

        return new_model

    def __add_processing(
        self,
        processors: List[Step],
        processor_connections: List[List[IoMapEntry]],
        items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
    ):
        """
        Add the pre/post processing steps and join with existing steps.

        Args:
            processors: List of processors to add items to.
            processor_connections: Populated with connections between each step. 1:1 with entries in processors.
            items: Items to add to processors.
                   Can be:
                     A Step instance. This will be implicitly joined to the immediately previous Step if one exists.
                     A tuple of (Step instance, list of IoMapEntry)
                       The IoMapEntry values are used to manually join an output from a producer Step to an input
                       of the current Step.
                       In each IoMapEntry, if a step name is provided the producer Step will be searched for in all
                       predecessor steps. It is valid for a post-processor step to consume output from a
                       pre-processor step.

        Raises:
            ValueError: if an item is neither a Step nor a tuple, if a named producer is not found, or if an explicit
                        IoMapEntry has no producer and there is no previous Step to infer it from.
        """

        for item in items:
            step = None
            explicit_io_map_entries = None

            if isinstance(item, Step):
                step = item
            elif isinstance(item, tuple):
                step, explicit_io_map_entries = item
            else:
                raise ValueError("Unexpected type " + str(type(item)))

            # start with implicit joins and replace with explicitly provided ones
            # this allows the user to specify the minimum number of manual joins.
            io_map_entries = [None] * len(step.input_names)  # type: List[Union[None,IoMapEntry]]
            prev_step = processors[-1] if processors else None
            if prev_step:
                # default is connecting as many outputs from the previous step as possible
                for i in range(0, min(len(prev_step.output_names), len(step.input_names))):
                    io_map_entries[i] = IoMapEntry(prev_step, i, i)

            # add explicit connections
            if explicit_io_map_entries:
                for entry in explicit_io_map_entries:
                    if entry.producer:
                        producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found
                    elif prev_step:
                        producer = prev_step
                    else:
                        # previously this produced an IoMapEntry with no producer that only failed later in run()
                        raise ValueError(
                            f"IoMapEntry for Step '{step.name}' has no producer and there is no previous Step.")

                    io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
                    self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))

            processors.append(step)
            processor_connections.append([entry for entry in io_map_entries if entry is not None])

    def __producer_from_step_or_str(self, entry: Union[Step, str]):
        """
        Resolve a producer given either a Step instance or the name of a registered Step.

        A Step argument is returned as-is. A name is searched for in pre_processors first, then post_processors.

        Raises:
            ValueError: if no registered Step has the given name, or if entry is neither a Step nor a str.
        """
        if isinstance(entry, Step):
            return entry

        if isinstance(entry, str):
            match = next((s for s in self.pre_processors if s.name == entry), None)
            if match is None:
                match = next((s for s in self.post_processors if s.name == entry), None)

            if match is None:
                raise ValueError(f"Step named {entry} was not found")

            return match

        # previously fell through and returned None, which surfaced later as a confusing failure
        raise ValueError(f"Unexpected producer type {type(entry)}. Expected a Step or a Step name.")

    @staticmethod
    def __cleanup_graph_output_names(graph: onnx.GraphProto):
        """
        Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
        Not essential but makes the graph outputs look far nicer.

        Each graph output whose name starts with Step.prefix is routed through an Identity node whose output has the
        prefixing stripped. Clashes with graph input names get an '_out' suffix, and remaining clashes get a numeric
        suffix. Other outputs are left untouched.
        """

        # for each output create identity node to remove prefixing
        io_map = []
        fixes = onnx.GraphProto()

        # manually handle naming clashes
        input_names = {i.name for i in graph.input}
        used_names = set(input_names)
        conflicts = 0

        for o in graph.output:
            if not o.name.startswith(Step.prefix):
                continue

            # we will create a small graph to do the renames so the output of the original graph will be an input
            # to that 'fixer' graph
            io_map.append((o.name, o.name))
            clean_name = o.name
            while clean_name.startswith(Step.prefix):
                # output from last step will have one prefixing stage that adds Step._prefix + '_'
                # e.g. '_ppp8_<orig_name>'
                next_underscore = clean_name.find("_", 1)
                if next_underscore < 0:
                    # no '_' terminating the prefix (e.g. an original output named '_pppx'). nothing more can be
                    # stripped, so stop instead of looping forever on an unchanged name.
                    break

                # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
                if len(clean_name) > next_underscore + 1:
                    next_underscore += 1
                clean_name = clean_name[next_underscore:]

            # handle things like super resolution where there's an 'image' input and 'image' output
            if clean_name in input_names:
                clean_name += "_out"

            orig_clean_name = clean_name
            while clean_name in used_names:
                conflicts += 1
                clean_name = f"{orig_clean_name}{conflicts}"

            used_names.add(clean_name)

            renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
            fixes.node.append(renamer)
            fixes.input.append(o)

            new_output = fixes.output.add()
            new_output.name = clean_name
            new_output.type.CopyFrom(o.type)

        # merge if we have any renaming to do
        if io_map:
            graph = onnx.compose.merge_graphs(graph, fixes, io_map)

        return graph
|
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
2
|
+
# Licensed under the MIT License.
|
|
3
|
+
|
|
4
|
+
import abc
|
|
5
|
+
import onnx
|
|
6
|
+
|
|
7
|
+
from onnx import parser
|
|
8
|
+
from typing import List, Optional, Tuple
|
|
9
|
+
|
|
10
|
+
from .utils import (
|
|
11
|
+
IoMapEntry,
|
|
12
|
+
create_custom_op_checker_context,
|
|
13
|
+
TENSOR_TYPE_TO_ONNX_TYPE,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Step(object):
    """Base class for a pre or post processing step."""

    prefix = "_ppp"
    _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step

    def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
        """
        Initialize the step.

        Args:
            inputs: List of default input names.
            outputs: List of default output names.
            name: Step name. Defaults to the derived class name.
        """
        self.step_num = Step._step_num
        self.input_names = inputs
        self.output_names = outputs
        self.name = name if name else f"{self.__class__.__name__}"
        self._prefix = f"{Step.prefix}{self.step_num}_"

        Step._step_num += 1

    def connect(self, entry: IoMapEntry):
        """
        Connect the value name from a previous step to an input of this step so they match.
        This makes joining the GraphProto created by each step trivial.

        Args:
            entry: IoMapEntry whose producer is a Step. producer_idx indexes producer.output_names and
                   consumer_idx indexes this step's input_names (negative values count from the end, as with lists).

        Raises:
            ValueError: if the producer is not a Step or either index is out of range.
        """
        # check the producer type before the index checks, which read producer.output_names
        if not isinstance(entry.producer, Step):
            raise ValueError(f"IoMapEntry producer must be a Step. Got {type(entry.producer)}.")

        num_outputs = len(entry.producer.output_names)
        if not -num_outputs <= entry.producer_idx < num_outputs:
            raise ValueError(
                f"producer_idx {entry.producer_idx} is out of range for {num_outputs} outputs of the producer Step.")

        num_inputs = len(self.input_names)
        if not -num_inputs <= entry.consumer_idx < num_inputs:
            raise ValueError(
                f"consumer_idx {entry.consumer_idx} is out of range for {num_inputs} inputs of Step '{self.name}'.")

        self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]

    def _connect_graph_inputs(self, graph_inputs: List[str]):
        "Internal method to connect names of the first pre-processor step with the graph inputs"
        for i, input_name in enumerate(graph_inputs):
            self.input_names[i] = input_name

    def apply(self, graph: onnx.GraphProto,
              checker_context: onnx.checker.C.CheckerContext,
              graph_outputs_to_maintain: Optional[List[str]] = None):
        """
        Create a graph for this step that can be appended to the provided graph.
        The PrePostProcessor will handle merging the two.

        Args:
            graph: Graph produced so far. Its outputs are the values available to this step.
            checker_context: ONNX checker context. Its default-domain opset is passed to _create_graph_for_step.
            graph_outputs_to_maintain: Names of values to keep as outputs of the merged graph.
                                       A value with multiple consumers would otherwise be consumed by the first one
                                       and be unavailable to later steps.
                                       These names are supplied by IOEntryValuePreserver. None means no extra outputs.

        Returns:
            The merged graph. As a side effect, self.output_names is updated to the prefixed names.
        """

        onnx_opset = checker_context.opset_imports[""]
        graph_for_step = self._create_graph_for_step(graph, onnx_opset)
        onnx.checker.check_graph(graph_for_step, checker_context)

        # prefix the graph for this step to guarantee no clashes of value names with the existing graph
        onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
        result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)

        # update self.output_names to the prefixed names so that when we connect later Steps the values match
        new_outputs = [self._prefix + o for o in self.output_names]
        result_outputs = [o.name for o in result.output]

        # sanity check that all of our outputs are in the merged graph
        for o in new_outputs:
            assert o in result_outputs

        self.output_names = new_outputs

        return result

    @abc.abstractmethod
    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Derived class should implement this and return the GraphProto containing the nodes required to
        implement the step.

        Note: Step does not use abc.ABCMeta, so the abstractmethod decorator documents intent but is not enforced
        when a subclass is instantiated.

        Args:
            graph: Graph the step will be appended to. Use to determine the types and shapes of values to connect.
            onnx_opset: The ONNX opset being targeted.
        """
        pass

    def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
                graph_outputs_to_maintain: Optional[List[str]] = None):
        """
        Merge `second` (already prefixed with self._prefix) after `first`.

        Each output of `first` whose prefixed name is an input of `second` is connected and dropped from the outputs.
        The outputs of the merged graph are the unconsumed outputs of `first`, then the outputs of `second`, then any
        graph_outputs_to_maintain not already included.
        """
        # the default of None previously raised TypeError when iterated below
        outputs_to_maintain = graph_outputs_to_maintain or []

        # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
        first_output = [o.name for o in first.output]
        second_input_names = {i.name for i in second.input}
        io_map = []
        for o in first.output:
            # apply the same prefix to the output from the previous step to match the prefixed graph from this step
            prefixed_output = self._prefix + o.name
            if prefixed_output in second_input_names:
                io_map.append((o.name, prefixed_output))
                first_output.remove(o.name)

        graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output]
        graph_outputs += [o for o in outputs_to_maintain if o not in graph_outputs]

        # merge with existing graph
        merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)

        return merged_graph

    @staticmethod
    def _elem_type_str(elem_type: int):
        """Map an ONNX tensor element type value to its name in the ONNX textual syntax."""
        return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]

    @staticmethod
    def _shape_to_str(shape: onnx.TensorShapeProto):
        """Returns the values from the shape as a comma separated string."""

        def dim_to_str(dim):
            # a dim may hold a fixed value, a symbolic name, or neither (unknown, rendered as an empty string)
            if dim.HasField("dim_value"):
                return str(dim.dim_value)
            elif dim.HasField("dim_param"):
                return dim.dim_param
            else:
                return ""

        shape_str = ",".join([dim_to_str(dim) for dim in shape.dim])
        return shape_str

    def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TypeProto.Tensor:
        """
        Get the tensor type of input `input_num` from the outputs of the graph we're appending to.

        Raises:
            ValueError: if no graph output has the name in self.input_names[input_num].
        """
        wanted_name = self.input_names[input_num]
        input_type = next((o.type.tensor_type for o in graph.output if o.name == wanted_name), None)

        if input_type is None:
            raise ValueError(f"Input {self.input_names[input_num]} was not found in outputs of graph.")

        return input_type

    def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
        """Return (element type name, comma-separated shape) for input `input_num` of this step."""
        input_type = self._input_tensor_type(graph, input_num)
        return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class Debug(Step):
    """
    Pass-through Step that exposes values of the pipeline as graph outputs for debugging.

    It can be inserted anywhere in the pre or post processing pipeline. Each of the first `num_inputs` values
    produced by the previous Step is copied twice: one copy (suffix "_next") feeds the next Step and the other
    (suffix "_debug") is left as a graph output.
    """

    def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
        """
        Initialize Debug step

        Args:
            num_inputs: Number of inputs from previous Step to make graph outputs.
            name: Optional name for Step. Defaults to 'Debug'
        """
        self._num_inputs = num_inputs
        super().__init__(
            [f"input{n}" for n in range(num_inputs)],
            [f"debug{n}" for n in range(num_inputs)],
            name,
        )

    def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
        """
        Build a graph of Identity nodes that duplicates each input into a "_next" and a "_debug" output.

        Raises:
            ValueError: if more inputs were requested than the graph has outputs.
        """
        available = len(graph.output)
        if self._num_inputs > available:
            raise ValueError(
                f"Debug step requested {self._num_inputs} inputs, but graph only has {available} outputs.")

        names_in = self.input_names
        passthrough = [f"{nm}_next" for nm in names_in]
        debug_copies = [f"{nm}_debug" for nm in names_in]
        # output_names keeps all "_next" names first, followed by all "_debug" names
        self.output_names = passthrough + debug_copies

        inputs_decl, outputs_decl, node_lines = [], [], []
        for idx in range(self._num_inputs):
            elem_type, dims = self._get_input_type_and_shape_strs(graph, idx)
            type_decl = f"{elem_type}[{dims}]"
            src = names_in[idx]
            next_name = passthrough[idx]
            debug_name = debug_copies[idx]

            inputs_decl.append(f"{type_decl} {src}")
            # graph outputs are interleaved per input: the "_next" copy, then the "_debug" copy
            outputs_decl.extend((f"{type_decl} {next_name}", f"{type_decl} {debug_name}"))
            node_lines.extend((f"{next_name} = Identity({src})\n", f"{debug_name} = Identity({src})\n"))

        # f-string can't have back-slash
        node_str = '\n'.join(node_lines)
        debug_graph = onnx.parser.parse_graph(
            f"""\
debug ({','.join(inputs_decl)})
=> ({','.join(outputs_decl)})
{{
{node_str}
}}
"""
        )

        onnx.checker.check_graph(debug_graph)
        return debug_graph
|