onnxruntime_extensions-0.14.0-cp313-cp313-macosx_11_0_arm64.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (43)
  1. onnxruntime_extensions/__init__.py +82 -0
  2. onnxruntime_extensions/_cuops.py +564 -0
  3. onnxruntime_extensions/_extensions_pydll.cpython-313-darwin.so +0 -0
  4. onnxruntime_extensions/_extensions_pydll.pyi +45 -0
  5. onnxruntime_extensions/_hf_cvt.py +331 -0
  6. onnxruntime_extensions/_ocos.py +133 -0
  7. onnxruntime_extensions/_ortapi2.py +274 -0
  8. onnxruntime_extensions/_torch_cvt.py +231 -0
  9. onnxruntime_extensions/_version.py +2 -0
  10. onnxruntime_extensions/cmd.py +66 -0
  11. onnxruntime_extensions/cvt.py +306 -0
  12. onnxruntime_extensions/onnxprocess/__init__.py +12 -0
  13. onnxruntime_extensions/onnxprocess/_builder.py +53 -0
  14. onnxruntime_extensions/onnxprocess/_onnx_ops.py +1507 -0
  15. onnxruntime_extensions/onnxprocess/_session.py +355 -0
  16. onnxruntime_extensions/onnxprocess/_tensor.py +628 -0
  17. onnxruntime_extensions/onnxprocess/torch_wrapper.py +31 -0
  18. onnxruntime_extensions/pnp/__init__.py +13 -0
  19. onnxruntime_extensions/pnp/_base.py +124 -0
  20. onnxruntime_extensions/pnp/_imagenet.py +65 -0
  21. onnxruntime_extensions/pnp/_nlp.py +148 -0
  22. onnxruntime_extensions/pnp/_onnx_ops.py +1544 -0
  23. onnxruntime_extensions/pnp/_torchext.py +310 -0
  24. onnxruntime_extensions/pnp/_unifier.py +45 -0
  25. onnxruntime_extensions/pnp/_utils.py +302 -0
  26. onnxruntime_extensions/pp_api.py +83 -0
  27. onnxruntime_extensions/tools/__init__.py +0 -0
  28. onnxruntime_extensions/tools/add_HuggingFace_CLIPImageProcessor_to_model.py +171 -0
  29. onnxruntime_extensions/tools/add_pre_post_processing_to_model.py +535 -0
  30. onnxruntime_extensions/tools/pre_post_processing/__init__.py +4 -0
  31. onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py +395 -0
  32. onnxruntime_extensions/tools/pre_post_processing/step.py +227 -0
  33. onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py +6 -0
  34. onnxruntime_extensions/tools/pre_post_processing/steps/general.py +366 -0
  35. onnxruntime_extensions/tools/pre_post_processing/steps/nlp.py +344 -0
  36. onnxruntime_extensions/tools/pre_post_processing/steps/vision.py +1157 -0
  37. onnxruntime_extensions/tools/pre_post_processing/utils.py +139 -0
  38. onnxruntime_extensions/util.py +186 -0
  39. onnxruntime_extensions-0.14.0.dist-info/LICENSE +21 -0
  40. onnxruntime_extensions-0.14.0.dist-info/METADATA +102 -0
  41. onnxruntime_extensions-0.14.0.dist-info/RECORD +43 -0
  42. onnxruntime_extensions-0.14.0.dist-info/WHEEL +6 -0
  43. onnxruntime_extensions-0.14.0.dist-info/top_level.txt +1 -0
onnxruntime_extensions/tools/pre_post_processing/pre_post_processor.py
@@ -0,0 +1,395 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ import onnx
+
+ from onnx import version_converter
+ from typing import List, Tuple, Union
+
+ from .utils import (
+     IoMapEntry,
+     IOEntryValuePreserver,
+     create_custom_op_checker_context,
+     sanitize_output_names,
+     TENSOR_TYPE_TO_ONNX_TYPE,
+ )
+ from .step import Step
+
+
+ class PrePostProcessor:
+     """
+     Class to handle running all the pre/post processing steps and updating the model.
+     """
+
+     def __init__(self, inputs: List[onnx.ValueInfoProto] = None, onnx_opset: int = 16):
+         """
+         Create a PrePostProcessor instance.
+
+         Args:
+             inputs: The inputs the model will use if pre-processing is added.
+             onnx_opset: The ONNX opset to use.
+                         Minimum is 16. Opset 18 or higher is strongly preferred if image resizing is involved,
+                         due to the anti-aliasing support added to Resize in opset 18.
+         """
+
+         if onnx_opset < 16:
+             raise ValueError("ONNX opset must be 16 or later.")
+
+         self._onnx_opset = onnx_opset
+         self._custom_op_checker_context = create_custom_op_checker_context(onnx_opset)
+
+         self.pre_processors = []
+         self.post_processors = []
+
+         # Connections for each pre/post processor. 1:1 mapping with entries in pre_processors/post_processors.
+         self._pre_processor_connections = []  # type: List[List[IoMapEntry]]
+         self._post_processor_connections = []  # type: List[List[IoMapEntry]]
+
+         # Explicitly join outputs from Steps in pre_processors to inputs of the original model.
+         # Format is (Step or step name, step_idx, name of graph input/output).
+         #
+         # For pre-processing we connect a Step output to an original model input:
+         #   - step_idx indexes Step.output_names, and name is in graph.input
+         #
+         # For post-processing we connect an original model output to a Step input:
+         #   - step_idx indexes Step.input_names, and name is in graph.output
+         self._pre_processing_joins = None  # type: Union[None, List[Tuple[Union[Step, str], int, str]]]
+         self._post_processing_joins = None  # type: Union[None, List[Tuple[Union[Step, str], int, str]]]
+
+         self._inputs = inputs if inputs else []
+
+         # Preserve outputs referenced by an IoMapEntry so they are not consumed by intermediate steps.
+         # With IOEntryValuePreserver an output value can have more than one consumer: the preserver
+         # keeps the value as a graph output until the consuming step is done.
+         self.outputs_preserver = []  # type: List[IOEntryValuePreserver]
+
+     def add_pre_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
+         """
+         Add the pre-processing steps. The last step is automatically joined to the original model inputs.
+
+         Options are:
+           Add a Step, with default connection of outputs from the previous step (if available) to inputs of
+           this step.
+           Add a tuple of a Step and a list of IoMapEntry instances for manual connections to previous steps.
+           These are used to override any automatic connections.
+             If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
+             If IoMapEntry.producer is a step name it must match the name of a previous step.
+         """
+         self.__add_processing(self.pre_processors, self._pre_processor_connections, items)
+
+     def add_post_processing(self, items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]]):
+         """
+         Add the post-processing steps. The first step is automatically joined to the original model outputs.
+
+         Options are:
+           Add a Step, with default connection of outputs from the previous step (if available) to inputs of
+           this step.
+           Add a tuple of a Step and a list of IoMapEntry instances for manual connections to previous steps.
+           These are used to override any automatic connections.
+             If IoMapEntry.producer is None it is inferred to be the immediately previous Step.
+             If IoMapEntry.producer is a step name it must match the name of a previous step.
+         """
+         self.__add_processing(self.post_processors, self._post_processor_connections, items)
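
To illustrate the connection options described above, a minimal sketch is shown below. StepA, StepB, and StepC are hypothetical steps, and IoMapEntry(producer, producer_idx, consumer_idx) follows the field order used throughout this module:

    # Hypothetical pipeline mixing implicit and explicit connections.
    pipeline = PrePostProcessor(model_inputs, onnx_opset=18)  # model_inputs: List[onnx.ValueInfoProto]
    pipeline.add_pre_processing([
        StepA(),  # hypothetical step producing outputs 0 and 1
        StepB(),  # implicitly consumes StepA output 0
        # StepC input 0 still comes implicitly from StepB; the IoMapEntry overrides input 1 so it
        # comes from StepA output 1. Naming the producer lets it be any predecessor step.
        (StepC(), [IoMapEntry(producer="StepA", producer_idx=1, consumer_idx=1)]),
    ])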
+
+     def _add_connection(self, consumer: Step, entry: IoMapEntry):
+         producer = self.__producer_from_step_or_str(entry.producer)
+
+         # Black does annoying things with the multi-line 'if' conditions making the code far less readable
+         # fmt: off
+         if not ((producer in self.pre_processors or producer in self.post_processors) and
+                 (consumer in self.pre_processors or consumer in self.post_processors)):
+             raise ValueError("Producer and Consumer processors must both be registered")
+
+         if producer in self.pre_processors:
+             if (consumer in self.pre_processors and
+                     self.pre_processors.index(producer) > self.pre_processors.index(consumer)):
+                 raise ValueError("Producer was registered after consumer and cannot be connected")
+         elif producer in self.post_processors:
+             if consumer not in self.post_processors:
+                 raise ValueError("Cannot connect pre-processor consumer with post-processor producer")
+             elif self.post_processors.index(producer) > self.post_processors.index(consumer):
+                 raise ValueError("Producer was registered after consumer and cannot be connected")
+         # fmt: on
+
+         assert isinstance(producer, Step)
+         consumer.connect(entry)
+
+     def run(self, model: onnx.ModelProto):
+         """
+         Update the model with the graph from each step in the pre and post processing pipelines.
+
+         Args:
+             model: Model to add the pre/post processing to.
+
+         Returns:
+             Model with the pre/post processing in it.
+         """
+
+         # Update the input model to the ONNX opset we're using. This is required as we implement the steps
+         # based on the operator specs for this opset.
+         model_opset = [
+             entry.version for entry in model.opset_import if entry.domain == "" or entry.domain == "ai.onnx"
+         ][0]
+
+         if model_opset > self._onnx_opset:
+             # It will probably work if the user updates PRE_POST_PROCESSING_ONNX_OPSET to match the model,
+             # but there are no guarantees. It would only break if ONNX operators used in the pre/post
+             # processing graphs have had spec changes.
+             raise ValueError(f"Model opset is {model_opset} which is newer than the opset used by this script.")
+         elif model_opset < self._onnx_opset:
+             model = onnx.version_converter.convert_version(model, self._onnx_opset)
+
+         def name_nodes(new_graph: onnx.GraphProto, prefix: str):
+             # simple helper so all nodes are named. this makes it far easier to debug any issues.
+             idx = 0
+             for n in new_graph.node:
+                 if not n.name:
+                     n.name = prefix + str(idx)
+                     idx += 1
+
+         def preserved_apply(processor: Step, *args):
+             # Activate the IOEntryValuePreserver instances to preserve outputs, and deactivate them
+             # once the current graph has consumed those outputs.
+             for preserver in self.outputs_preserver:
+                 if preserver.consumer == processor:
+                     preserver.is_active = False
+
+             # IOEntryValuePreserver preserves outputs that have multiple consumers by explicitly
+             # adding them to the graph outputs.
+             graph_outputs_to_maintain = [i.output for i in self.outputs_preserver if i.is_active]
+             graph_for_step = processor.apply(*args, graph_outputs_to_maintain=graph_outputs_to_maintain)
+
+             for preserver in self.outputs_preserver:
+                 if preserver.producer == processor:
+                     preserver.is_active = True
+                     preserver.output = processor.output_names[preserver.producer_idx]
+
+             return graph_for_step
+
+         def connect_and_run(graph: onnx.GraphProto, processor: Step, connections: List[IoMapEntry]):
+             for connection in connections:
+                 assert connection.producer
+                 self._add_connection(processor, connection)
+
+             return preserved_apply(processor, graph, self._custom_op_checker_context)
+
+         # fix any invalid output names now if we're adding post-processing, as onnx parse_graph can't handle them
+         if self.post_processors:
+             sanitize_output_names(model.graph)
+
+         graph = model.graph
+
+         # add pre-processing
+         if self.pre_processors:
+             # create empty graph with pass-through of the requested input names
+             pre_process_graph = onnx.GraphProto()
+             for i in self._inputs:
+                 pre_process_graph.input.append(i)
+                 pre_process_graph.output.append(i)
+
+             # connect the graph input names to the first pre-processing step based on order
+             self.pre_processors[0]._connect_graph_inputs([vi.name for vi in self._inputs])
+
+             for idx, step in enumerate(self.pre_processors):
+                 pre_process_graph = connect_and_run(pre_process_graph, step, self._pre_processor_connections[idx])
+
+             # name all the nodes for easier debugging
+             name_nodes(pre_process_graph, "pre_process_")
+
+             if not self._pre_processing_joins:
+                 # default to 1:1 between the outputs of the last step and the inputs of the original model
+                 last_step = self.pre_processors[-1]
+                 num_entries = min(len(last_step.output_names), len(graph.input))
+                 self._pre_processing_joins = [(last_step, i, graph.input[i].name) for i in range(0, num_entries)]
+
+             # map the pre-processing outputs to graph inputs
+             # (a cleaner way to get the possible outputs after merge_graphs may be needed)
+             step_graph_outputs = [o.name for o in pre_process_graph.output]
+             io_map = []  # type: List[Tuple[str, str]]
+             for step, step_idx, graph_input in self._pre_processing_joins:
+                 io_map.append((step.output_names[step_idx], graph_input))
+                 step_graph_outputs.remove(step.output_names[step_idx])
+
+             # add outputs from previous IoMapEntry producers to maintain them as graph outputs
+             # until consumed by the final Step that requires them.
+             step_graph_outputs += [o.name for o in graph.output if o.name not in step_graph_outputs]
+             external_outputs = [
+                 i.output for i in self.outputs_preserver if i.is_active and i.output not in step_graph_outputs]
+             if external_outputs:
+                 step_graph_outputs.extend(external_outputs)
+
+             graph = onnx.compose.merge_graphs(pre_process_graph, graph, io_map, outputs=step_graph_outputs)
+
+         # add post-processing
+         if self.post_processors:
+             orig_model_outputs = [o.name for o in model.graph.output]
+             graph_outputs = [o.name for o in graph.output]  # this may have additional outputs from pre-processing
+
+             # create default joins if needed
+             if not self._post_processing_joins:
+                 # default to 1:1 between the outputs of the original model and the inputs of the first
+                 # post-processing step
+                 first_step = self.post_processors[0]
+                 num_entries = min(len(first_step.input_names), len(orig_model_outputs))
+                 self._post_processing_joins = [(first_step, i, orig_model_outputs[i]) for i in range(0, num_entries)]
+
+             # update the input names for the steps to match the values produced by the model
+             for step, step_idx, graph_output in self._post_processing_joins:
+                 assert graph_output in graph_outputs
+                 step.input_names[step_idx] = graph_output
+
+             # create empty graph with the values that will be available to the post-processing
+             post_process_graph = onnx.GraphProto()
+             for o in graph.output:
+                 post_process_graph.input.append(o)
+                 post_process_graph.output.append(o)
+
+             for idx, step in enumerate(self.post_processors):
+                 post_process_graph = connect_and_run(post_process_graph, step, self._post_processor_connections[idx])
+
+             name_nodes(post_process_graph, "post_process_")
+
+             # io_map should be 1:1 with the post-processing graph given we updated the step input names to match
+             io_map = [(o, o) for o in graph_outputs]
+             graph = onnx.compose.merge_graphs(graph, post_process_graph, io_map)
+
+         # Make the output names nicer by removing the prefixing that occurred when applying the steps
+         graph = PrePostProcessor.__cleanup_graph_output_names(graph)
+
+         opset_imports = [onnx.helper.make_operatorsetid(domain, opset)
+                          for domain, opset in self._custom_op_checker_context.opset_imports.items()]
+         # find_min_ir_version_for doesn't support custom domains until ONNX 1.14, so extract the ONNX opset
+         # from the imports and only pass that in.
+         ir_version = onnx.helper.find_min_ir_version_for([entry for entry in opset_imports
+                                                           if entry.domain == "" or entry.domain == "ai.onnx"])
+         new_model = onnx.helper.make_model(graph, opset_imports=opset_imports, ir_version=ir_version)
+
+         onnx.checker.check_model(new_model)
+
+         return new_model
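
run is built on onnx.compose.merge_graphs, where each io_map entry pairs an output name from the first graph with an input name of the second. A self-contained sketch of that mechanism, using toy graphs rather than anything from this package:

    import onnx
    from onnx import compose, parser

    g1 = parser.parse_graph("squared (float[2] x) => (float[2] y) { y = Mul(x, x) }")
    g2 = parser.parse_graph("doubled (float[2] y_in) => (float[2] z) { z = Add(y_in, y_in) }")

    # Join g1's output 'y' to g2's input 'y_in' and expose only 'z', so the merged
    # graph computes z = 2 * x * x. This mirrors how run() chains the step graphs.
    merged = compose.merge_graphs(g1, g2, io_map=[("y", "y_in")], outputs=["z"])
    print([o.name for o in merged.output])  # ['z']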
+
+     def __add_processing(
+         self,
+         processors: List[Step],
+         processor_connections: List[List[IoMapEntry]],
+         items: List[Union[Step, Tuple[Step, List[IoMapEntry]]]],
+     ):
+         """
+         Add the pre/post processing steps and join them with the existing steps.
+
+         Args:
+             processors: List of processors to add the items to.
+             processor_connections: Populated with connections between each step. 1:1 with entries in processors.
+             items: Items to add to processors.
+                    Can be:
+                      A Step instance. This will be implicitly joined to the immediately previous Step if one
+                      exists.
+                      A tuple of (Step instance, list of IoMapEntry).
+                        The IoMapEntry values are used to manually join an output from a producer Step to an
+                        input of the current Step.
+                        In each IoMapEntry, if a step name is provided the producer Step will be searched for in
+                        all predecessor steps. It is valid for a post-processor step to consume output from a
+                        pre-processor step.
+         """
+
+         for item in items:
+             step = None
+             explicit_io_map_entries = None
+
+             if isinstance(item, Step):
+                 step = item
+             elif isinstance(item, tuple):
+                 step, explicit_io_map_entries = item
+             else:
+                 raise ValueError("Unexpected type " + str(type(item)))
+
+             # Start with implicit joins and replace them with any explicitly provided ones.
+             # This allows the user to specify the minimum number of manual joins.
+             io_map_entries = [None] * len(step.input_names)  # type: List[Union[None, IoMapEntry]]
+             prev_step = None if len(processors) == 0 else processors[-1]
+             if prev_step:
+                 # default is connecting as many outputs from the previous step as possible
+                 for i in range(0, min(len(prev_step.output_names), len(step.input_names))):
+                     io_map_entries[i] = IoMapEntry(prev_step, i, i)
+
+             # add explicit connections
+             if explicit_io_map_entries:
+                 for entry in explicit_io_map_entries:
+                     if not entry.producer:
+                         producer = prev_step
+                     else:
+                         producer = self.__producer_from_step_or_str(entry.producer)  # throws if not found
+
+                     io_map_entries[entry.consumer_idx] = IoMapEntry(producer, entry.producer_idx, entry.consumer_idx)
+                     self.outputs_preserver.append(IOEntryValuePreserver(producer, step, entry.producer_idx))
+
+             processors.append(step)
+             processor_connections.append([entry for entry in io_map_entries if entry is not None])
+
+     def __producer_from_step_or_str(self, entry: Union[Step, str]):
+         if isinstance(entry, Step):
+             return entry
+
+         if isinstance(entry, str):
+             match = (next((s for s in self.pre_processors if s.name == entry), None) or
+                      next((s for s in self.post_processors if s.name == entry), None))  # fmt: skip
+
+             if not match:
+                 raise ValueError(f"Step named {entry} was not found")
+
+             return match
+
+     @staticmethod
+     def __cleanup_graph_output_names(graph: onnx.GraphProto):
+         """
+         Hide the prefixing of names that happens when we merge the graphs from the pre/post processing steps.
+         Not essential, but it makes the graph outputs look far nicer.
+         """
+
+         # for each prefixed output, create an Identity node to remove the prefixing
+         io_map = []
+         fixes = onnx.GraphProto()
+
+         # manually handle naming clashes
+         input_names = set([i.name for i in graph.input])
+         used_names = set(input_names)
+         conflicts = 0
+
+         for o in graph.output:
+             if not o.name.startswith(Step.prefix):
+                 continue
+
+             # We create a small graph to do the renames, so the output of the original graph becomes an
+             # input to that 'fixer' graph.
+             io_map.append((o.name, o.name))
+
+             clean_name = o.name
+             while clean_name.startswith(Step.prefix):
+                 # output from the last step will have one prefixing stage that adds Step.prefix + step num + '_'
+                 # e.g. '_ppp8_<orig_name>'
+                 next_underscore = clean_name.find("_", 1)
+                 if next_underscore > 0:
+                     # this check shouldn't be necessary as we always add the trailing '_' when prefixing...
+                     if len(clean_name) > next_underscore + 1:
+                         next_underscore += 1
+                     clean_name = clean_name[next_underscore:]
+
+             # handle things like super resolution where there's an 'image' input and an 'image' output
+             if clean_name in input_names:
+                 clean_name += "_out"
+
+             orig_clean_name = clean_name
+             while clean_name in used_names:
+                 conflicts += 1
+                 clean_name = f"{orig_clean_name}{conflicts}"
+
+             used_names.add(clean_name)
+
+             renamer = onnx.helper.make_node("Identity", [o.name], [clean_name], f"Rename {o.name}")
+             fixes.node.append(renamer)
+             fixes.input.append(o)
+
+             new_output = fixes.output.add()
+             new_output.name = clean_name
+             new_output.type.CopyFrom(o.type)
+
+         # merge if we have any renaming to do
+         if io_map:
+             graph = onnx.compose.merge_graphs(graph, fixes, io_map)
+
+         return graph
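
Putting this file together, a hedged end-to-end sketch follows. The model path is hypothetical, and the import path plus the Resize/CenterCrop signatures are assumptions based on the steps listed in steps/vision.py in this wheel:

    import onnx
    from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, Resize, CenterCrop

    model = onnx.load("model.onnx")  # hypothetical input model
    new_input = onnx.helper.make_tensor_value_info("image", onnx.TensorProto.UINT8, ["h", "w", 3])

    pipeline = PrePostProcessor(inputs=[new_input], onnx_opset=18)
    pipeline.add_pre_processing([Resize(256), CenterCrop(224, 224)])  # assumed step signatures

    new_model = pipeline.run(model)  # validated with onnx.checker before being returned
    onnx.save_model(new_model, "model.with_preprocessing.onnx")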
onnxruntime_extensions/tools/pre_post_processing/step.py
@@ -0,0 +1,227 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ import abc
+ import onnx
+
+ from onnx import parser
+ from typing import List, Optional, Tuple
+
+ from .utils import (
+     IoMapEntry,
+     create_custom_op_checker_context,
+     TENSOR_TYPE_TO_ONNX_TYPE,
+ )
+
+
+ class Step(object):
+     """Base class for a pre or post processing step."""
+
+     prefix = "_ppp"
+     _step_num = 0  # unique step number so we can prefix the naming in the graph created for the step
+
+     def __init__(self, inputs: List[str], outputs: List[str], name: Optional[str] = None):
+         """
+         Initialize the step.
+
+         Args:
+             inputs: List of default input names.
+             outputs: List of default output names.
+             name: Step name. Defaults to the derived class name.
+         """
+         self.step_num = Step._step_num
+         self.input_names = inputs
+         self.output_names = outputs
+         self.name = name if name else f"{self.__class__.__name__}"
+         self._prefix = f"{Step.prefix}{self.step_num}_"
+
+         Step._step_num += 1
+
+     def connect(self, entry: IoMapEntry):
+         """
+         Connect the value name from a previous step to an input of this step so they match.
+         This makes joining the GraphProto created by each step trivial.
+         """
+         assert len(entry.producer.output_names) >= entry.producer_idx
+         assert len(self.input_names) >= entry.consumer_idx
+         assert isinstance(entry.producer, Step)
+
+         self.input_names[entry.consumer_idx] = entry.producer.output_names[entry.producer_idx]
+
+     def _connect_graph_inputs(self, graph_inputs: List[str]):
+         """Internal method to connect the inputs of the first pre-processing step to the graph inputs."""
+         for i, input_name in enumerate(graph_inputs):
+             self.input_names[i] = input_name
+
+     def apply(self, graph: onnx.GraphProto,
+               checker_context: onnx.checker.C.CheckerContext,
+               graph_outputs_to_maintain: List[str]):
+         """
+         Create a graph for this step that can be appended to the provided graph.
+         The PrePostProcessor will handle merging the two.
+
+         Args:
+             graph_outputs_to_maintain: List of output names to keep as graph outputs with additional effort.
+                 An output with multiple consumers would otherwise be consumed by the first merge, preventing
+                 connections from subsequent steps. These names are generated by IOEntryValuePreserver.
+         """
+
+         onnx_opset = checker_context.opset_imports[""]
+         graph_for_step = self._create_graph_for_step(graph, onnx_opset)
+         onnx.checker.check_graph(graph_for_step, checker_context)
+
+         # prefix the graph for this step to guarantee no clashes of value names with the existing graph
+         onnx.compose.add_prefix_graph(graph_for_step, self._prefix, inplace=True)
+         result = self.__merge(graph, graph_for_step, graph_outputs_to_maintain)
+
+         # update self.output_names to the prefixed names so that when we connect later Steps the values match
+         new_outputs = [self._prefix + o for o in self.output_names]
+         result_outputs = [o.name for o in result.output]
+
+         # sanity check that all of our outputs are in the merged graph
+         for o in new_outputs:
+             assert o in result_outputs
+
+         self.output_names = new_outputs
+
+         return result
+
+     @abc.abstractmethod
+     def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
+         """
+         Derived classes should implement this and return the GraphProto containing the nodes required to
+         implement the step.
+
+         Args:
+             graph: Graph the step will be appended to. Use it to determine the types and shapes of the
+                    values to connect.
+             onnx_opset: The ONNX opset being targeted.
+         """
+         pass
+
+     def __merge(self, first: onnx.GraphProto, second: onnx.GraphProto,
+                 graph_outputs_to_maintain: Optional[List[str]] = None):
+         # We prefixed all the value names in `second`, so allow for that when connecting the two graphs
+         first_output = [o.name for o in first.output]
+         io_map = []
+
+         for o in first.output:
+             # apply the same prefix to the output from the previous step to match the prefixed graph from
+             # this step
+             prefixed_output = self._prefix + o.name
+             for i in second.input:
+                 if i.name == prefixed_output:
+                     io_map.append((o.name, i.name))
+                     first_output.remove(o.name)
+
+         graph_outputs = first_output + [o.name for o in second.output if o.name not in first_output]
+         graph_outputs += [o for o in graph_outputs_to_maintain if o not in graph_outputs]
+
+         # merge with the existing graph
+         merged_graph = onnx.compose.merge_graphs(first, second, io_map, outputs=graph_outputs)
+
+         return merged_graph
+
+     @staticmethod
+     def _elem_type_str(elem_type: int):
+         return TENSOR_TYPE_TO_ONNX_TYPE[elem_type]
+
+     @staticmethod
+     def _shape_to_str(shape: onnx.TensorShapeProto):
+         """Returns the values from the shape as a comma-separated string."""
+
+         def dim_to_str(dim):
+             if dim.HasField("dim_value"):
+                 return str(dim.dim_value)
+             elif dim.HasField("dim_param"):
+                 return dim.dim_param
+             else:
+                 return ""
+
+         shape_str = ",".join([dim_to_str(dim) for dim in shape.dim])
+         return shape_str
+
+     def _input_tensor_type(self, graph: onnx.GraphProto, input_num: int) -> onnx.TensorProto:
+         """Get the onnx.TensorProto for the input from the outputs of the graph we're appending to."""
+
+         input_type = None
+         for o in graph.output:
+             if o.name == self.input_names[input_num]:
+                 input_type = o.type.tensor_type
+                 break
+
+         if not input_type:
+             raise ValueError(f"Input {self.input_names[input_num]} was not found in the outputs of the graph.")
+
+         return input_type
+
+     def _get_input_type_and_shape_strs(self, graph: onnx.GraphProto, input_num: int) -> Tuple[str, str]:
+         input_type = self._input_tensor_type(graph, input_num)
+         return Step._elem_type_str(input_type.elem_type), Step._shape_to_str(input_type.shape)
+
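
A derived step only needs to implement _create_graph_for_step. As a hedged sketch, here is a hypothetical step that doubles its single input, following the same onnx.parser.parse_graph pattern the bundled Debug step below uses:

    class DoubleValue(Step):
        """Hypothetical step that doubles its single input by adding it to itself."""

        def __init__(self, name: Optional[str] = None):
            super().__init__(["data"], ["doubled"], name)

        def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
            # Match the type and shape of the value this step consumes from the preceding graph.
            input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, 0)
            return onnx.parser.parse_graph(
                f"""\
                double ({input_type_str}[{input_shape_str}] {self.input_names[0]})
                    => ({input_type_str}[{input_shape_str}] {self.output_names[0]})
                {{
                    {self.output_names[0]} = Add({self.input_names[0]}, {self.input_names[0]})
                }}
                """
            )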
+
+ class Debug(Step):
+     """
+     Step that can be arbitrarily inserted in the pre or post processing pipeline.
+     It makes the outputs of the previous Step also become graph outputs so their values can be more easily
+     debugged.
+
+     Each output is duplicated: one copy is renamed with the suffix "_next" and another with the suffix
+     "_debug". The "_next" outputs feed into the next step, and the "_debug" outputs become graph outputs.
+     """
+
+     def __init__(self, num_inputs: int = 1, name: Optional[str] = None):
+         """
+         Initialize the Debug step.
+
+         Args:
+             num_inputs: Number of inputs from the previous Step to make graph outputs.
+             name: Optional name for the Step. Defaults to 'Debug'.
+         """
+         self._num_inputs = num_inputs
+         input_names = [f"input{i}" for i in range(0, num_inputs)]
+         output_names = [f"debug{i}" for i in range(0, num_inputs)]
+
+         super().__init__(input_names, output_names, name)
+
+     def _create_graph_for_step(self, graph: onnx.GraphProto, onnx_opset: int):
+         if self._num_inputs > len(graph.output):
+             raise ValueError(
+                 f"Debug step requested {self._num_inputs} inputs, but the graph only has {len(graph.output)} "
+                 "outputs.")
+
+         debug_offset = len(self.input_names)
+         # update the output names so we preserve info from the latest input names
+         self.output_names = [f"{name}_next" for name in self.input_names]
+         self.output_names += [f"{name}_debug" for name in self.input_names]
+
+         input_str_list = []
+         output_str_list = []
+         nodes_str_list = []
+         for i in range(0, self._num_inputs):
+             input_type_str, input_shape_str = self._get_input_type_and_shape_strs(graph, i)
+
+             input_str_list.append(f"{input_type_str}[{input_shape_str}] {self.input_names[i]}")
+
+             output_str_list.append(f"{input_type_str}[{input_shape_str}] {self.output_names[i]}")
+             output_str_list.append(f"{input_type_str}[{input_shape_str}] {self.output_names[debug_offset + i]}")
+
+             nodes_str_list.append(f"{self.output_names[i]} = Identity({self.input_names[i]})\n")
+             nodes_str_list.append(f"{self.output_names[debug_offset + i]} = Identity({self.input_names[i]})\n")
+
+         # f-string expressions can't contain a backslash, so join the node strings first
+         node_str = "\n".join(nodes_str_list)
+         debug_graph = onnx.parser.parse_graph(
+             f"""\
+             debug ({','.join(input_str_list)})
+                 => ({','.join(output_str_list)})
+             {{
+                 {node_str}
+             }}
+             """
+         )
+
+         onnx.checker.check_graph(debug_graph)
+         return debug_graph
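
In use, a Debug step can be dropped between two existing steps without rewiring them; the intermediate value then also appears as a model output. Resize and CenterCrop are assumed step names here:

    pipeline.add_pre_processing([
        Resize(256),
        Debug(),              # Resize's output flows on as '..._next' and is also exposed as '..._debug'
        CenterCrop(224, 224),
    ])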
onnxruntime_extensions/tools/pre_post_processing/steps/__init__.py
@@ -0,0 +1,6 @@
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+
+ from .general import *
+ from .vision import *
+ from .nlp import *