onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,669 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
|
3
|
+
import numpy as np
|
|
4
|
+
import onnx
|
|
5
|
+
import torch
|
|
6
|
+
from ..helpers.torch_helper import to_tensor, to_numpy
|
|
7
|
+
from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
|
|
8
|
+
from .report_results_comparison import ReportResultComparison
|
|
9
|
+
from . import torch_ops
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@functools.lru_cache
def get_kernels() -> Dict[Tuple[str, str, int], type[torch_ops.OpRunKernel]]:
    """
    Retrieves all the available kernels class :class:`TorchOnnxEvaluator`
    can use. A kernel is any class exposed by module ``torch_ops``
    inheriting from ``OpRunKernel`` and named ``<op_type>_<version>``.
    The domain is read from the class attribute ``domain`` and defaults
    to ``""``. The result is cached: the same dictionary is returned
    to every caller and should not be modified.
    The full list is the following.

    .. runpython::
        :showcode:

        from onnx_diagnostic.reference.torch_evaluator import get_kernels

        for k, v in sorted(get_kernels().items()):
            domain, name, version = k
            f = f"{name}({version})" if domain == "" else f"{name}[{domain}]({version})"
            add = " " * max(25 - len(f), 0)
            dd = " -- device dependent" if v.device_dependent() else ""
            print(f"{f}{add} -- {v.__name__}{dd}")

    :return: dictionary ``{(domain, op_type, version): kernel class}``
    """
    res = {}
    for v in torch_ops.__dict__.values():
        if not (isinstance(v, type) and issubclass(v, torch_ops.OpRunKernel)):
            continue
        # The version is whatever follows the last underscore, the operator
        # name itself may contain underscores.
        name, sep, version = v.__name__.rpartition("_")
        if not sep or not name or not version.isdecimal():
            # Not a versioned kernel (base class, helper, ...).
            continue
        res[getattr(v, "domain", ""), name, int(version)] = v
    return res
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class TorchOnnxEvaluator:
|
|
40
|
+
"""
|
|
41
|
+
Torch evaluator for onnx models.
|
|
42
|
+
The evaluator does not store the original proto it evaluates (presumably to avoid keeping a second copy of the model in memory).
|
|
43
|
+
|
|
44
|
+
:param proto: a proto
|
|
45
|
+
:param providers: where to run the model
|
|
46
|
+
:param opsets: needed if proto is a graph
|
|
47
|
+
:param local_functions: known local functions
|
|
48
|
+
:param verbose: verbosity level
|
|
49
|
+
:param custom_kernels: dictionary of kernels the user can define to override
|
|
50
|
+
a specific implementation: ``("", "LayerNormalization"): CustomKernel``
|
|
51
|
+
|
|
52
|
+
The class holds the following attributes:
|
|
53
|
+
|
|
54
|
+
* `providers`: providers
|
|
55
|
+
* `default_device`: default torch device
|
|
56
|
+
* `constants`: all initializers or constants
|
|
57
|
+
* `kernels`: kernels
|
|
58
|
+
* `runtime_info`: produced by :func:`first_used_last_used
|
|
59
|
+
<onnx_diagnostic.torch_onnx.runtime_info.first_used_last_used>`
|
|
60
|
+
* `last_used`: contains the list of intermediate results,
|
|
61
|
+
to remove after every node execution,
|
|
62
|
+
this prevents the memory from growing too much
|
|
63
|
+
* `functions`: local functions
|
|
64
|
+
|
|
65
|
+
The class is not multithreaded. `runtime_info` gets updated
|
|
66
|
+
by the class. The list of available kernels is returned by function
|
|
67
|
+
:func:`onnx_diagnostic.reference.torch_evaluator.get_kernels`.
|
|
68
|
+
Example:
|
|
69
|
+
|
|
70
|
+
.. runpython::
|
|
71
|
+
:showcode:
|
|
72
|
+
|
|
73
|
+
import onnx
|
|
74
|
+
import onnx.helper as oh
|
|
75
|
+
import torch
|
|
76
|
+
from onnx_diagnostic.helpers import string_type
|
|
77
|
+
from onnx_diagnostic.reference import TorchOnnxEvaluator
|
|
78
|
+
|
|
79
|
+
TFLOAT = onnx.TensorProto.FLOAT
|
|
80
|
+
|
|
81
|
+
proto = oh.make_model(
|
|
82
|
+
oh.make_graph(
|
|
83
|
+
[
|
|
84
|
+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
|
|
85
|
+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
|
|
86
|
+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
|
|
87
|
+
],
|
|
88
|
+
"-nd-",
|
|
89
|
+
[
|
|
90
|
+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
|
|
91
|
+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
|
|
92
|
+
],
|
|
93
|
+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
|
|
94
|
+
),
|
|
95
|
+
opset_imports=[oh.make_opsetid("", 18)],
|
|
96
|
+
ir_version=9,
|
|
97
|
+
)
|
|
98
|
+
|
|
99
|
+
sess = TorchOnnxEvaluator(proto)
|
|
100
|
+
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
|
|
101
|
+
result = sess.run(None, feeds)
|
|
102
|
+
print(string_type(result, with_shape=True, with_min_max=True))
|
|
103
|
+
|
|
104
|
+
With ``verbose=1``, the class prints out every kernel run and
|
|
105
|
+
every result deleted along the run.
|
|
106
|
+
It shows when a result is not needed anymore. In that case,
|
|
107
|
+
it is deleted to free the memory it takes.
|
|
108
|
+
|
|
109
|
+
.. runpython::
|
|
110
|
+
:showcode:
|
|
111
|
+
|
|
112
|
+
import onnx
|
|
113
|
+
import onnx.helper as oh
|
|
114
|
+
import torch
|
|
115
|
+
from onnx_diagnostic.helpers import string_type
|
|
116
|
+
from onnx_diagnostic.reference import TorchOnnxEvaluator
|
|
117
|
+
|
|
118
|
+
TFLOAT = onnx.TensorProto.FLOAT
|
|
119
|
+
|
|
120
|
+
proto = oh.make_model(
|
|
121
|
+
oh.make_graph(
|
|
122
|
+
[
|
|
123
|
+
oh.make_node("Sigmoid", ["Y"], ["sy"]),
|
|
124
|
+
oh.make_node("Mul", ["Y", "sy"], ["ysy"]),
|
|
125
|
+
oh.make_node("Mul", ["X", "ysy"], ["final"]),
|
|
126
|
+
],
|
|
127
|
+
"-nd-",
|
|
128
|
+
[
|
|
129
|
+
oh.make_tensor_value_info("X", TFLOAT, [1, "b", "c"]),
|
|
130
|
+
oh.make_tensor_value_info("Y", TFLOAT, ["a", "b", "c"]),
|
|
131
|
+
],
|
|
132
|
+
[oh.make_tensor_value_info("final", TFLOAT, ["a", "b", "c"])],
|
|
133
|
+
),
|
|
134
|
+
opset_imports=[oh.make_opsetid("", 18)],
|
|
135
|
+
ir_version=9,
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
sess = TorchOnnxEvaluator(proto, verbose=1)
|
|
139
|
+
feeds = dict(X=torch.rand((4, 5)), Y=torch.rand((4, 5)))
|
|
140
|
+
result = sess.run(None, feeds)
|
|
141
|
+
print(string_type(result, with_shape=True, with_min_max=True))
|
|
142
|
+
|
|
143
|
+
The runtime can also execute the kernels of the onnx model on CUDA.
|
|
144
|
+
It follows the same logic as :class:`onnxruntime.InferenceSession`:
|
|
145
|
+
``providers=["CUDAExecutionProvider"]``.
|
|
146
|
+
It is better in that case to move the input on CUDA. The class
|
|
147
|
+
tries to move every weight on CUDA but tries to keep any tensor
|
|
148
|
+
identified as a shape in CPU. Some bugs may remain as torch
|
|
149
|
+
raises an exception when devices are expected to be the same.
|
|
150
|
+
The runtime was validated with model :epkg:`arnir0/Tiny-LLM`.
|
|
151
|
+
Next example shows how to replace a kernel with a different
|
|
152
|
+
one based on :epkg:`onnxruntime`.
|
|
153
|
+
|
|
154
|
+
.. runpython::
|
|
155
|
+
:showcode:
|
|
156
|
+
|
|
157
|
+
import numpy as np
|
|
158
|
+
import onnx
|
|
159
|
+
import onnx.helper as oh
|
|
160
|
+
import onnxruntime
|
|
161
|
+
import torch
|
|
162
|
+
from onnx_diagnostic.helpers import string_type
|
|
163
|
+
from onnx_diagnostic.helpers.torch_helper import onnx_dtype_to_torch_dtype
|
|
164
|
+
from onnx_diagnostic.reference import TorchOnnxEvaluator
|
|
165
|
+
from onnx_diagnostic.reference.torch_ops import OpRunKernel, OpRunTensor
|
|
166
|
+
|
|
167
|
+
TFLOAT16 = onnx.TensorProto.FLOAT16
|
|
168
|
+
|
|
169
|
+
class LayerNormalizationOrt(OpRunKernel):
|
|
170
|
+
"LayerNormalization based on onnxruntime"
|
|
171
|
+
|
|
172
|
+
def __init__(self, node: onnx.NodeProto, version=None, verbose=0):
|
|
173
|
+
super().__init__(node, version, verbose=verbose)
|
|
174
|
+
self.axis = self.get_attribute_int(node, "axis", -1)
|
|
175
|
+
self.epsilon = self.get_attribute_float(node, "epsilon", 1e-5)
|
|
176
|
+
self.stash_type = onnx_dtype_to_torch_dtype(
|
|
177
|
+
self.get_attribute_int(node, "stash_type", onnx.TensorProto.FLOAT)
|
|
178
|
+
)
|
|
179
|
+
self.compute_std = len(node.output) > 1
|
|
180
|
+
assert not self.compute_std, "The kernel only computes the first output."
|
|
181
|
+
layer_model = oh.make_model(
|
|
182
|
+
oh.make_graph(
|
|
183
|
+
[
|
|
184
|
+
oh.make_node(
|
|
185
|
+
"LayerNormalization",
|
|
186
|
+
["X", "W", "B"],
|
|
187
|
+
["Z"],
|
|
188
|
+
axis=-1,
|
|
189
|
+
epsilon=9.999999974752427e-7,
|
|
190
|
+
)
|
|
191
|
+
],
|
|
192
|
+
"dummy",
|
|
193
|
+
[
|
|
194
|
+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
|
|
195
|
+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
|
|
196
|
+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
|
|
197
|
+
],
|
|
198
|
+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
|
|
199
|
+
),
|
|
200
|
+
ir_version=9,
|
|
201
|
+
opset_imports=[oh.make_opsetid("", 17)],
|
|
202
|
+
)
|
|
203
|
+
self.ort_sess = onnxruntime.InferenceSession(
|
|
204
|
+
layer_model.SerializeToString(), providers=["CUDAExecutionProvider"]
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
def run(self, x, scale, bias=None):
|
|
208
|
+
print(f"-- running {self.__class__.__name__}")
|
|
209
|
+
feeds = dict(X=x, W=scale)
|
|
210
|
+
if bias is not None:
|
|
211
|
+
feeds["B"] = bias
|
|
212
|
+
feeds = {k: v.tensor.detach().cpu().numpy() for k, v in feeds.items()}
|
|
213
|
+
got = self.ort_sess.run(None, feeds)[0]
|
|
214
|
+
return OpRunTensor(torch.from_numpy(got).to(x.dtype).to(x.device))
|
|
215
|
+
|
|
216
|
+
# This kernel is tested on this model.
|
|
217
|
+
model = oh.make_model(
|
|
218
|
+
oh.make_graph(
|
|
219
|
+
[
|
|
220
|
+
oh.make_node(
|
|
221
|
+
"LayerNormalization",
|
|
222
|
+
["X", "W", "B"],
|
|
223
|
+
["ln"],
|
|
224
|
+
axis=-1,
|
|
225
|
+
epsilon=9.999999974752427e-7,
|
|
226
|
+
),
|
|
227
|
+
oh.make_node(
|
|
228
|
+
"Add", ["ln", "W"], ["Z"], axis=-1, epsilon=9.999999974752427e-7
|
|
229
|
+
),
|
|
230
|
+
],
|
|
231
|
+
"dummy",
|
|
232
|
+
[
|
|
233
|
+
oh.make_tensor_value_info("X", TFLOAT16, ["b", "c", "d"]),
|
|
234
|
+
oh.make_tensor_value_info("W", TFLOAT16, ["d"]),
|
|
235
|
+
oh.make_tensor_value_info("B", TFLOAT16, ["d"]),
|
|
236
|
+
],
|
|
237
|
+
[oh.make_tensor_value_info("Z", TFLOAT16, ["b", "c", "d"])],
|
|
238
|
+
),
|
|
239
|
+
ir_version=9,
|
|
240
|
+
opset_imports=[oh.make_opsetid("", 17)],
|
|
241
|
+
)
|
|
242
|
+
|
|
243
|
+
torch_sess = TorchOnnxEvaluator(
|
|
244
|
+
model,
|
|
245
|
+
custom_kernels={("", "LayerNormalization"): LayerNormalizationOrt},
|
|
246
|
+
verbose=1,
|
|
247
|
+
)
|
|
248
|
+
feeds = dict(
|
|
249
|
+
zip(
|
|
250
|
+
torch_sess.input_names,
|
|
251
|
+
[
|
|
252
|
+
torch.rand(3, 4, 5, dtype=torch.float16),
|
|
253
|
+
torch.abs(torch.rand(5, dtype=torch.float16)),
|
|
254
|
+
torch.rand(5, dtype=torch.float16),
|
|
255
|
+
],
|
|
256
|
+
)
|
|
257
|
+
)
|
|
258
|
+
res = torch_sess.run(None, feeds)
|
|
259
|
+
print(string_type(res, with_shape=True, with_min_max=True))
|
|
260
|
+
"""
|
|
261
|
+
|
|
262
|
+
class IO:
    """
    Describes one input or one output of the model.

    :param name: name of the input or output
    :param type: element type, an integer
        (the evaluator fills it with ``tensor_type.elem_type``)
    :param shape: shape, every dimension is either a name (dynamic dimension)
        or an integer
    """

    def __init__(self, name: str, type: int, shape: Tuple[Union[str, int], ...]):
        self.name = name
        self.type = type
        self.shape = shape

    def __repr__(self) -> str:
        "Returns a string showing the three attributes, useful for debugging."
        return (
            f"{self.__class__.__name__}(name={self.name!r}, "
            f"type={self.type!r}, shape={self.shape!r})"
        )
|
|
269
|
+
|
|
270
|
+
@classmethod
|
|
271
|
+
def _on_cuda(cls, providers) -> int:
|
|
272
|
+
if not providers:
|
|
273
|
+
return -1
|
|
274
|
+
for p in providers:
|
|
275
|
+
if p == "CUDAExecutionProvider":
|
|
276
|
+
return 0
|
|
277
|
+
if isinstance(p, tuple) and p[0] == "CUDAExecutionProvider":
|
|
278
|
+
return p[1]["device_id"]
|
|
279
|
+
return -1
|
|
280
|
+
|
|
281
|
+
def _make_io(self, value_infos) -> List["TorchOnnxEvaluator.IO"]:
    "Describes every element of *value_infos* (name, element type, shape) with class ``IO``."
    return [
        self.IO(
            name=i.name,
            type=i.type.tensor_type.elem_type,
            # A dimension is either named (dim_param) or fixed (dim_value).
            shape=tuple(d.dim_param or d.dim_value for d in i.type.tensor_type.shape.dim),
        )
        for i in value_infos
    ]

def __init__(
    self,
    proto: Union[str, onnx.FunctionProto, onnx.GraphProto, onnx.ModelProto],
    providers: Tuple[str, ...] = ("CPUExecutionProvider",),
    opsets: Optional[Dict[str, int]] = None,
    local_functions: Optional[Dict[Tuple[str, str], "TorchOnnxEvaluator"]] = None,
    verbose: int = 0,
    custom_kernels: Optional[Dict[Tuple[str, str], type[torch_ops.OpRunKernel]]] = None,
):
    """
    Builds the constants and the kernels needed to evaluate *proto*.

    :param proto: a model, a graph, a function or a filename loaded
        with :func:`onnx.load`
    :param providers: where to run the model, the default device is CUDA
        if ``CUDAExecutionProvider`` is one of them, CPU otherwise
    :param opsets: required if *proto* is a graph, must be None otherwise
    :param local_functions: known local functions indexed by ``(domain, name)``,
        the dictionary is copied
    :param verbose: verbosity level
    :param custom_kernels: kernels replacing the default implementation,
        indexed by ``(domain, op_type)``
    :raises TypeError: if *proto* has an unexpected type
    """
    self.providers = providers
    self.constants: Dict[str, torch.Tensor] = {}
    self.kernels: List[Optional[torch_ops.OpRunKernel]] = []
    self.functions = local_functions.copy() if local_functions else {}
    self.CPU = torch.tensor([0]).to("cpu").device
    self.verbose = verbose
    self.custom_kernels = custom_kernels or {}
    dev = self._on_cuda(providers)
    if dev < 0:
        self.default_device = self.CPU
        self.CUDA = None
    else:
        self.CUDA = torch.tensor([0]).to(f"cuda:{dev}").device
        self.default_device = self.CUDA

    if isinstance(proto, str):
        proto = onnx.load(proto)
    if isinstance(proto, onnx.ModelProto):
        assert opsets is None, "proto is a model, opsets must be None in that case"
        assert not proto.graph.sparse_initializer, "sparse_initializer not support yet"
        self.opsets = {d.domain: d.version for d in proto.opset_import}
        # Local functions are evaluated by nested evaluators, a function
        # can only call the functions defined before it.
        for f in proto.functions:
            self.functions[f.domain, f.name] = self.__class__(
                f,
                providers=providers,
                local_functions=self.functions,
                verbose=self.verbose,
            )
        self._build_initializers(proto.graph.initializer)
        self._build_initializers(proto.graph.node)
        self._build_kernels(proto.graph.node)
        self.input_names = [i.name for i in proto.graph.input]
        self.output_names = [i.name for i in proto.graph.output]
        self._io_input_names = self._make_io(proto.graph.input)
        self._io_output_names = self._make_io(proto.graph.output)
    elif isinstance(proto, onnx.GraphProto):
        assert opsets, "opsets must be specified if proto is a graph"
        assert not proto.sparse_initializer, "sparse_initializer not support yet"
        self.opsets = opsets
        self._build_initializers(proto.initializer)
        self._build_initializers(proto.node)
        self._build_kernels(proto.node)
        self.input_names = [i.name for i in proto.input]
        self.output_names = [i.name for i in proto.output]
        # A graph describes its inputs and outputs the same way a model does.
        self._io_input_names = self._make_io(proto.input)
        self._io_output_names = self._make_io(proto.output)
    elif isinstance(proto, onnx.FunctionProto):
        assert opsets is None, "proto is a function, opsets must be None in that case"
        self.opsets = {d.domain: d.version for d in proto.opset_import}
        self._build_initializers(proto.node)
        self._build_kernels(proto.node)
        # A function only stores the names of its inputs and outputs,
        # there is no type or shape to expose through get_inputs, get_outputs.
        self.input_names = list(proto.input)
        self.output_names = list(proto.output)
    else:
        raise TypeError(f"Unexpected type {type(proto)} for proto")

    self.runtime_info = first_used_last_used(proto, constant_as_initializer=True)
    # last_used[i] contains the results to delete after kernel i was executed.
    self.last_used: List[List[str]] = [[] for _ in self.kernels]
    for name, info in self.runtime_info.items():
        assert isinstance(info.last_used, int) or info.is_input, (
            f"Missing field last_used in {info!r}, last_used={info.last_used!r}, "
            f"This may mean the node is unused and it should be removed."
        )
        if info.last_used is None:
            # Not used, it is deleted after the first node if there is one.
            if self.last_used:
                self.last_used[0].append(name)
        elif not info.is_output and not info.is_initializer:
            self.last_used[info.last_used].append(name)
|
|
374
|
+
|
|
375
|
+
def get_inputs(self):
|
|
376
|
+
"Same API than onnxruntime."
|
|
377
|
+
assert hasattr(self, "_io_input_names"), "Missing attribute '_io_input_names'."
|
|
378
|
+
return self._io_input_names
|
|
379
|
+
|
|
380
|
+
def get_outputs(self):
|
|
381
|
+
"Same API than onnxruntime."
|
|
382
|
+
assert hasattr(self, "_io_output_names"), "Missing attribute '_io_output_names'."
|
|
383
|
+
return self._io_output_names
|
|
384
|
+
|
|
385
|
+
@property
|
|
386
|
+
def on_cuda(self) -> bool:
|
|
387
|
+
"Tells if the default device is CUDA."
|
|
388
|
+
return self.default_device == self.CUDA
|
|
389
|
+
|
|
390
|
+
def _build_initializers(self, inits: Sequence[Union[onnx.NodeProto, onnx.TensorProto]]):
    """
    Converts initializers and nodes ``Constant`` of the main domain
    into torch tensors stored in ``self.constants``, moved to the
    default device.

    :param inits: initializers or nodes, any node which is not a node
        ``Constant`` from domain ``""`` is ignored

    A node ``Constant`` may define its value with one of the attributes
    ``value``, ``value_float``, ``value_floats``, ``value_int``, ``value_ints``.
    Floats are stored as float32, integers as int64. The method fails
    if none of them is present.
    """
    for init in inits:
        if isinstance(init, onnx.TensorProto):
            self.constants[init.name] = to_tensor(init).to(self.default_device)
            continue
        if not (
            isinstance(init, onnx.NodeProto)
            and init.op_type == "Constant"
            and init.domain == ""
        ):
            continue
        value = None
        for att in init.attribute:
            if att.name == "value":
                value = to_tensor(att.t)
            elif att.name == "value_float":
                value = torch.tensor(att.f, dtype=torch.float32)
            elif att.name == "value_floats":
                value = torch.tensor(list(att.floats), dtype=torch.float32)
            elif att.name == "value_int":
                value = torch.tensor(att.i, dtype=torch.int64)
            elif att.name == "value_ints":
                value = torch.tensor(list(att.ints), dtype=torch.int64)
        assert value is not None, f"No attribute value in node {init}"
        self.constants[init.output[0]] = value.to(self.default_device)
|
|
409
|
+
|
|
410
|
+
def _build_kernels(self, nodes: Sequence[onnx.NodeProto]):
    """
    Fills ``self.kernels`` with one entry per node, in the same order.

    :param nodes: nodes to convert

    For every node, the kernel is chosen in that order:
    a custom kernel registered for ``(domain, op_type)``,
    a local function, ``None`` for a node ``Constant`` of the main domain
    (it is handled as a constant), the kernel returned by
    :func:`get_kernels` with the highest version not greater than the opset
    of the domain. The method fails if no kernel is found.
    """
    registry = get_kernels()
    self.kernels.clear()

    def instantiate(kernel_cls, onnx_node, version, options):
        # A device dependent kernel receives the default device,
        # a kernel holding subgraphs receives this evaluator as its parent.
        positional = [self.default_device] if kernel_cls.device_dependent() else []
        keywords = dict(parent=self) if kernel_cls.has_subgraphs() else {}
        keywords.update(options)  # type: ignore[arg-type]
        return kernel_cls(onnx_node, version, *positional, **keywords)  # type: ignore[arg-type]

    for node in nodes:
        options = dict(verbose=max(0, self.verbose - 1))
        domain, op_type = node.domain, node.op_type
        version = self.opsets[domain]
        short_key = (domain, op_type)

        if short_key in self.custom_kernels:
            self.kernels.append(
                instantiate(self.custom_kernels[short_key], node, version, options)
            )
            continue

        if short_key in self.functions:
            self.kernels.append(
                torch_ops.OpRunFunction(self.functions[short_key], node, version, **options)
            )
            continue

        if op_type == "Constant" and domain == "":
            # Treated as a constant.
            self.kernels.append(None)
            continue

        # Looks for the most recent implementation not newer than the opset.
        while (domain, op_type, version) not in registry and version > 0:
            version -= 1
        assert (domain, op_type, version) in registry, (
            f"Missing kernel for node type {node.op_type!r} from domain {node.domain!r}, "
            f"local functions={sorted(self.functions)}"
        )
        self.kernels.append(
            instantiate(registry[domain, op_type, version], node, version, options)
        )
|
|
454
|
+
|
|
455
|
+
def run(
    self,
    outputs: Optional[List[str]],
    feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
    report_cmp: Optional[ReportResultComparison] = None,
) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
    """
    Runs the ONNX model.

    :param outputs: outputs required, None for all the outputs of the model
    :param feeds: inputs, torch tensors or numpy arrays indexed by name,
        numpy arrays are converted into torch tensors
    :param report_cmp: used as a reference,
        every intermediate results is compare to every existing one,
        if not empty, it is an instance of
        :class:`onnx_diagnostic.reference.ReportResultComparison`
    :return: output tensors, they are converted into numpy arrays
        if at least one input is a numpy array

    The values stored in ``runtime_info`` for the inputs and the outputs
    are removed before the method returns.
    """
    use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
    if use_numpy:
        # Only numpy arrays are converted, torch.from_numpy fails on a tensor
        # and the feeds may mix both types.
        feeds = {
            k: (torch.from_numpy(v) if isinstance(v, np.ndarray) else v)
            for k, v in feeds.items()
        }
    if outputs is None:
        outputs = self.output_names

    # sets constants
    for k, v in self.constants.items():
        r = self.runtime_info[k]
        if not r.has_value:
            r.set_value(
                torch_ops.OpRunTensor(
                    v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                    is_constant=True,
                    may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                )
            )
            if self.verbose:
                print(f"+C {r.name}: {r.string_type()}")

    # inputs
    for k, v in feeds.items():
        r = self.runtime_info[k]
        r.set_value(
            torch_ops.OpRunTensor(
                v.to(self.CUDA) if not r.is_shape and self.on_cuda else v,
                is_constant=False,
                may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
            )
        )
        if self.verbose:
            print(f"+I {r.name}: {r.string_type()}")

    # node execution
    for it, kernel in enumerate(self.kernels):
        # kernel is None for a node Constant, its value is already a constant.
        if kernel is not None:
            if self.verbose:
                print(
                    f"{kernel.__class__.__name__}"
                    f"({', '.join(kernel.input)}) -> "
                    f"{', '.join(kernel.output)}"
                )
            # kernel execution, an empty name means a missing optional input
            inputs = [(self.runtime_info[i].value if i else None) for i in kernel.input]
            if kernel.has_subgraphs():
                res = kernel.run(*inputs, context=self.runtime_info)  # type: ignore[call-arg]
            else:
                res = kernel.run(*inputs)
            if isinstance(res, tuple):
                # outputs
                assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
                    f"Unexpected output type {[type(o) for o in res]} "
                    f"for kernel {type(kernel)}."
                )
                for name, t in zip(kernel.output, res):
                    self.runtime_info[name].set_value(t)
                if self.verbose:
                    for name in kernel.output:
                        print(f"+R {name}: {self.runtime_info[name].string_type()}")
            else:
                assert isinstance(
                    res, torch_ops.OpRunValue
                ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                self.runtime_info[kernel.output[0]].set_value(res)
                if self.verbose:
                    print(
                        f"+R {kernel.output[0]}: "
                        f"{self.runtime_info[kernel.output[0]].string_type()}"
                    )
            if report_cmp:
                reported = report_cmp.report(
                    dict(
                        zip(
                            kernel.output,
                            (
                                tuple((r.tensor if r else None) for r in res)  # type: ignore[attr-defined]
                                if isinstance(res, tuple)
                                else ((res.tensor if res else None),)  # type: ignore[attr-defined]
                            ),
                        )
                    )
                )
                if self.verbose > 1:
                    print(f"  -- report {len(reported)} comparisons")

        # free intermediate results
        for name in self.last_used[it]:
            self.runtime_info[name].clean_value()
            if self.verbose:
                print(f"- clean {name}")

    assert all(
        self.runtime_info[o].value is not None for o in outputs
    ), "Not implemented yet when one output is None."
    fres = [self.runtime_info[o].value.tensor for o in outputs]  # type: ignore[union-attr]
    if self.verbose:
        print(f"++ outputs {', '.join(outputs)}")

    # clean previous execution
    for k in feeds:
        self.runtime_info[k].clean_value()
        if self.verbose:
            print(f"- clean {k}")
    for o in outputs:
        self.runtime_info[o].clean_value()
        if self.verbose:
            print(f"- clean {o}")

    if use_numpy:
        return [None if a is None else to_numpy(a) for a in fres]
    return fres
|
|
583
|
+
|
|
584
|
+
def run_with_values(
    self,
    *args: Optional[torch_ops.OpRunTensor],
    context: Optional[Dict[str, RuntimeValue]] = None,
) -> Union[torch_ops.OpRunValue, Tuple[torch_ops.OpRunValue, ...]]:
    """
    Runs the ONNX model. The signature is different.
    This method is called by every kernel holding a subgraph.
    The local variables are stored in `context`.

    :param args: inputs, in the same order as ``input_names``,
        None for a missing input
    :param context: local context for the execution of subgraphs,
        it is used to find any result which is not known by this evaluator
    :return: a copy of the output if there is only one,
        a tuple of copies otherwise
    """
    # None is a valid value, it is replaced by an empty OpRunTensor below.
    assert all(
        a is None or isinstance(a, torch_ops.OpRunValue) for a in args
    ), f"Unexpected type in args: {[type(a) for a in args]}"
    outputs = self.output_names
    context = context or {}

    # sets constants
    for k, v in self.constants.items():
        r = self.runtime_info[k]
        if not r.has_value:
            r.set_value(
                torch_ops.OpRunTensor(
                    v.to(self.CUDA) if r.is_shape is False and self.on_cuda else v,
                    is_constant=True,
                    may_cpu=len(v.shape) == 1 and v.numel() < 8 and v.dtype == torch.int64,
                )
            )

    # inputs
    for k, v in zip(self.input_names, args):
        r = self.runtime_info[k]
        r.set_value(
            torch_ops.OpRunTensor(None) if v is None else v.__class__(v.tensor_or_sequence)
        )

    # node execution
    for it, kernel in enumerate(self.kernels):
        # kernel is None for a node Constant, its value is already a constant.
        if kernel is not None:
            # kernel execution, a result unknown to this evaluator
            # is taken from the context, an empty name means a missing input
            inputs = [
                (
                    (
                        self.runtime_info[i].value
                        if i in self.runtime_info
                        else context[i].value
                    )
                    if i
                    else None
                )
                for i in kernel.input
            ]
            # NOTE(review): unlike method run, the context is not forwarded
            # to a kernel holding subgraphs -- confirm this is intended.
            res = kernel.run(*inputs)
            if isinstance(res, tuple):
                # outputs, the check is the same as the one done in method run
                assert all(isinstance(o, torch_ops.OpRunValue) for o in res), (
                    f"Unexpected output type {[type(o) for o in res]} "
                    f"for kernel {type(kernel)}."
                )
                for name, t in zip(kernel.output, res):
                    self.runtime_info[name].set_value(t)
            else:
                assert isinstance(
                    res, torch_ops.OpRunValue
                ), f"Unexpected output type {type(res)} for kernel {type(kernel)}."
                self.runtime_info[kernel.output[0]].set_value(res)

        # free intermediate results
        for name in self.last_used[it]:
            self.runtime_info[name].clean_value()

    assert all(
        self.runtime_info[o].value is not None for o in outputs
    ), "Not implemented yet when one output is None."
    # The outputs are copied because the values are cleaned just below.
    res2 = [self.runtime_info[o].value.copy() for o in outputs]  # type: ignore[assignment, union-attr]

    # clean previous execution
    for k in self.input_names:
        self.runtime_info[k].clean_value()
    for o in self.output_names:
        self.runtime_info[o].clean_value()

    return res2[0] if len(res2) == 1 else tuple(res2)  # type: ignore[index, return-value, arg-type]
|