onnx-diagnostic 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +7 -0
- onnx_diagnostic/__main__.py +4 -0
- onnx_diagnostic/_command_lines_parser.py +1141 -0
- onnx_diagnostic/api.py +15 -0
- onnx_diagnostic/doc.py +100 -0
- onnx_diagnostic/export/__init__.py +2 -0
- onnx_diagnostic/export/api.py +124 -0
- onnx_diagnostic/export/dynamic_shapes.py +1083 -0
- onnx_diagnostic/export/shape_helper.py +296 -0
- onnx_diagnostic/export/validate.py +173 -0
- onnx_diagnostic/ext_test_case.py +1290 -0
- onnx_diagnostic/helpers/__init__.py +1 -0
- onnx_diagnostic/helpers/_log_helper.py +463 -0
- onnx_diagnostic/helpers/args_helper.py +132 -0
- onnx_diagnostic/helpers/bench_run.py +450 -0
- onnx_diagnostic/helpers/cache_helper.py +687 -0
- onnx_diagnostic/helpers/config_helper.py +170 -0
- onnx_diagnostic/helpers/doc_helper.py +163 -0
- onnx_diagnostic/helpers/fake_tensor_helper.py +273 -0
- onnx_diagnostic/helpers/graph_helper.py +386 -0
- onnx_diagnostic/helpers/helper.py +1707 -0
- onnx_diagnostic/helpers/log_helper.py +2245 -0
- onnx_diagnostic/helpers/memory_peak.py +249 -0
- onnx_diagnostic/helpers/mini_onnx_builder.py +600 -0
- onnx_diagnostic/helpers/model_builder_helper.py +469 -0
- onnx_diagnostic/helpers/onnx_helper.py +1200 -0
- onnx_diagnostic/helpers/ort_session.py +736 -0
- onnx_diagnostic/helpers/rt_helper.py +476 -0
- onnx_diagnostic/helpers/torch_helper.py +987 -0
- onnx_diagnostic/reference/__init__.py +4 -0
- onnx_diagnostic/reference/evaluator.py +254 -0
- onnx_diagnostic/reference/ops/__init__.py +1 -0
- onnx_diagnostic/reference/ops/op_add_add_mul_mul.py +68 -0
- onnx_diagnostic/reference/ops/op_attention.py +60 -0
- onnx_diagnostic/reference/ops/op_average_pool_grad.py +63 -0
- onnx_diagnostic/reference/ops/op_bias_softmax.py +16 -0
- onnx_diagnostic/reference/ops/op_cast_like.py +46 -0
- onnx_diagnostic/reference/ops/op_complex.py +26 -0
- onnx_diagnostic/reference/ops/op_concat.py +15 -0
- onnx_diagnostic/reference/ops/op_constant_of_shape.py +67 -0
- onnx_diagnostic/reference/ops/op_fused_matmul.py +31 -0
- onnx_diagnostic/reference/ops/op_gather.py +29 -0
- onnx_diagnostic/reference/ops/op_gather_elements.py +45 -0
- onnx_diagnostic/reference/ops/op_gather_grad.py +12 -0
- onnx_diagnostic/reference/ops/op_memcpy_host.py +11 -0
- onnx_diagnostic/reference/ops/op_mul_sigmoid.py +23 -0
- onnx_diagnostic/reference/ops/op_negxplus1.py +8 -0
- onnx_diagnostic/reference/ops/op_qlinear_average_pool.py +40 -0
- onnx_diagnostic/reference/ops/op_qlinear_conv.py +102 -0
- onnx_diagnostic/reference/ops/op_quick_gelu.py +23 -0
- onnx_diagnostic/reference/ops/op_replace_zero.py +13 -0
- onnx_diagnostic/reference/ops/op_rotary.py +19 -0
- onnx_diagnostic/reference/ops/op_scan.py +65 -0
- onnx_diagnostic/reference/ops/op_scatter_elements.py +107 -0
- onnx_diagnostic/reference/ops/op_scatternd_of_shape.py +22 -0
- onnx_diagnostic/reference/ops/op_simplified_layer_normalization.py +8 -0
- onnx_diagnostic/reference/ops/op_skip_layer_normalization.py +13 -0
- onnx_diagnostic/reference/ops/op_slice.py +20 -0
- onnx_diagnostic/reference/ops/op_transpose_cast.py +16 -0
- onnx_diagnostic/reference/ops/op_tri_matrix.py +17 -0
- onnx_diagnostic/reference/ort_evaluator.py +652 -0
- onnx_diagnostic/reference/quantized_tensor.py +46 -0
- onnx_diagnostic/reference/report_results_comparison.py +95 -0
- onnx_diagnostic/reference/torch_evaluator.py +669 -0
- onnx_diagnostic/reference/torch_ops/__init__.py +56 -0
- onnx_diagnostic/reference/torch_ops/_op_run.py +335 -0
- onnx_diagnostic/reference/torch_ops/access_ops.py +94 -0
- onnx_diagnostic/reference/torch_ops/binary_ops.py +108 -0
- onnx_diagnostic/reference/torch_ops/controlflow_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/generator_ops.py +36 -0
- onnx_diagnostic/reference/torch_ops/nn_ops.py +196 -0
- onnx_diagnostic/reference/torch_ops/other_ops.py +106 -0
- onnx_diagnostic/reference/torch_ops/reduce_ops.py +130 -0
- onnx_diagnostic/reference/torch_ops/sequence_ops.py +65 -0
- onnx_diagnostic/reference/torch_ops/shape_ops.py +121 -0
- onnx_diagnostic/reference/torch_ops/unary_ops.py +93 -0
- onnx_diagnostic/tasks/__init__.py +90 -0
- onnx_diagnostic/tasks/automatic_speech_recognition.py +188 -0
- onnx_diagnostic/tasks/data/__init__.py +13 -0
- onnx_diagnostic/tasks/data/dummies_imagetext2text_generation_gemma3.onnx +0 -0
- onnx_diagnostic/tasks/feature_extraction.py +162 -0
- onnx_diagnostic/tasks/fill_mask.py +89 -0
- onnx_diagnostic/tasks/image_classification.py +144 -0
- onnx_diagnostic/tasks/image_text_to_text.py +581 -0
- onnx_diagnostic/tasks/image_to_video.py +127 -0
- onnx_diagnostic/tasks/mask_generation.py +143 -0
- onnx_diagnostic/tasks/mixture_of_expert.py +79 -0
- onnx_diagnostic/tasks/object_detection.py +134 -0
- onnx_diagnostic/tasks/sentence_similarity.py +89 -0
- onnx_diagnostic/tasks/summarization.py +227 -0
- onnx_diagnostic/tasks/text2text_generation.py +230 -0
- onnx_diagnostic/tasks/text_classification.py +89 -0
- onnx_diagnostic/tasks/text_generation.py +352 -0
- onnx_diagnostic/tasks/text_to_image.py +95 -0
- onnx_diagnostic/tasks/zero_shot_image_classification.py +128 -0
- onnx_diagnostic/torch_export_patches/__init__.py +21 -0
- onnx_diagnostic/torch_export_patches/eval/__init__.py +725 -0
- onnx_diagnostic/torch_export_patches/eval/model_cases.py +898 -0
- onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1098 -0
- onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +311 -0
- onnx_diagnostic/torch_export_patches/patch_details.py +340 -0
- onnx_diagnostic/torch_export_patches/patch_expressions.py +108 -0
- onnx_diagnostic/torch_export_patches/patch_inputs.py +211 -0
- onnx_diagnostic/torch_export_patches/patch_module.py +1047 -0
- onnx_diagnostic/torch_export_patches/patch_module_helper.py +184 -0
- onnx_diagnostic/torch_export_patches/patches/__init__.py +0 -0
- onnx_diagnostic/torch_export_patches/patches/patch_torch.py +1090 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +2139 -0
- onnx_diagnostic/torch_export_patches/serialization/__init__.py +46 -0
- onnx_diagnostic/torch_export_patches/serialization/diffusers_impl.py +34 -0
- onnx_diagnostic/torch_export_patches/serialization/transformers_impl.py +313 -0
- onnx_diagnostic/torch_models/__init__.py +0 -0
- onnx_diagnostic/torch_models/code_sample.py +343 -0
- onnx_diagnostic/torch_models/hghub/__init__.py +1 -0
- onnx_diagnostic/torch_models/hghub/hub_api.py +422 -0
- onnx_diagnostic/torch_models/hghub/hub_data.py +234 -0
- onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +4905 -0
- onnx_diagnostic/torch_models/hghub/model_inputs.py +388 -0
- onnx_diagnostic/torch_models/hghub/model_specific.py +76 -0
- onnx_diagnostic/torch_models/llms.py +2 -0
- onnx_diagnostic/torch_models/untrained/__init__.py +0 -0
- onnx_diagnostic/torch_models/untrained/llm_phi2.py +113 -0
- onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +76 -0
- onnx_diagnostic/torch_models/validate.py +2124 -0
- onnx_diagnostic/torch_onnx/__init__.py +0 -0
- onnx_diagnostic/torch_onnx/runtime_info.py +289 -0
- onnx_diagnostic/torch_onnx/sbs.py +440 -0
- onnx_diagnostic-0.8.0.dist-info/METADATA +213 -0
- onnx_diagnostic-0.8.0.dist-info/RECORD +132 -0
- onnx_diagnostic-0.8.0.dist-info/WHEEL +5 -0
- onnx_diagnostic-0.8.0.dist-info/licenses/LICENSE.txt +19 -0
- onnx_diagnostic-0.8.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,440 @@
|
|
|
1
|
+
from typing import Any, Dict, Iterator, Optional, Tuple, Union
|
|
2
|
+
import onnx
|
|
3
|
+
import torch
|
|
4
|
+
from ..helpers import string_type, string_diff, max_diff
|
|
5
|
+
from ..helpers.onnx_helper import to_array_extended
|
|
6
|
+
from ..helpers.torch_helper import to_numpy
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def validate_fx_tensor(
    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
) -> None:
    """
    Checks that a tensor produced for ``node`` matches the expected shape.

    :param node: fx node the tensor was produced for (only used in error messages)
    :param tensor: tensor to validate
    :param expected_shape: expected shape, entries may be ints or symbolic values
    """
    # Rank must match exactly.
    assert len(tensor.shape) == len(expected_shape), (
        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
    for got, expected in zip(tensor.shape, expected_shape):
        if not isinstance(expected, int):
            # Symbolic dimension: no static value to compare against.
            continue
        # A {0, 1} pair is tolerated as equal — NOTE(review): presumably to
        # absorb 0/1-dimension ambiguity in exported shapes; confirm.
        assert got == expected or {got, expected} == {0, 1}, (
            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
    """
    Validates the outputs of a node against the metadata stored in ``node.meta["val"]``.

    Nothing is checked when the node carries no ``"val"`` metadata.

    :param node: node that produced the outputs
    :param outputs: values produced by the node
    """
    if "val" not in node.meta:
        return
    meta_val = node.meta["val"]
    if isinstance(outputs, torch.Tensor):
        # Single tensor: compare against the recorded shape.
        validate_fx_tensor(node, outputs, meta_val.shape)
        return
    if isinstance(outputs, (tuple, list)):
        # Sequence of tensors: metadata must be a matching sequence.
        assert isinstance(meta_val, (list, tuple)), (
            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        assert len(outputs) == len(meta_val), (
            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        for produced, reference in zip(outputs, meta_val):
            validate_fx_tensor(node, produced, reference.shape)
        return
    if isinstance(outputs, int):
        # Either the metadata is symbolic (nothing to check) or the values agree.
        assert (
            isinstance(meta_val, (torch.SymInt, torch.SymBool, torch.SymFloat))
            or outputs == meta_val
        ), (
            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    if outputs is None:
        assert meta_val is None, (
            f"None mismatch, got {outputs} expected {node.meta['val']}, "
            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
            f"node.args={node.args}, node.kwargs={node.kwargs}, "
            f"node.meta={node.meta}"
        )
        return
    raise NotImplementedError(
        f"Validation for output type {type(outputs)} is not implemented, "
        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
        f"node.args={node.args}, node.kwargs={node.kwargs}, "
        f"node.meta={node.meta}"
    )
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def run_fx_node(
    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
) -> Tuple[Any, ...]:
    """
    Executes a single fx node.

    Only ``output`` and ``call_function`` nodes are supported.

    :param node: node to execute
    :param args: positional inputs for the node
    :param kwargs: named inputs for the node
    :return: results produced by the node
    :raises NotImplementedError: for any other kind of node
    """
    op = node.op
    if op == "output":
        # The output node simply forwards its single positional argument.
        assert len(args) == 1 and not kwargs, (
            f"Unexpected inputs: args={string_type(args, limit=20)} "
            f"kwargs={string_type(kwargs, limit=20)}"
        )
        return args
    if op == "call_function":
        fct = node.target
        assert callable(fct), f"{node.target!r} not callable in node {node!r}"
        results = fct(*args, **(kwargs or {}))
        # Check the produced values against the node metadata.
        validate_fx_outputs(node, results)
        return results
    raise NotImplementedError(
        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
|
|
117
|
+
"See :func:`prepare_args_kwargs`."
|
|
118
|
+
if isinstance(ref, torch.fx.Node):
|
|
119
|
+
return torch_results[ref.name]
|
|
120
|
+
if isinstance(ref, list):
|
|
121
|
+
return [_pick_result(torch_results, n) for n in ref]
|
|
122
|
+
if isinstance(ref, tuple):
|
|
123
|
+
return tuple(_pick_result(torch_results, n) for n in ref)
|
|
124
|
+
if isinstance(ref, dict):
|
|
125
|
+
return {k: _pick_result(torch_results, v) for k, v in ref.items()}
|
|
126
|
+
if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
|
|
127
|
+
return ref
|
|
128
|
+
if ref is None:
|
|
129
|
+
return None
|
|
130
|
+
raise NotImplementedError(f"Unable to process args type {type(ref)}")
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def prepare_args_kwargs(
    torch_results: Dict[str, Any], node: torch.fx.Node
) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
    """
    Resolves the args and kwargs of an fx node against already computed results.

    :param torch_results: results computed so far, indexed by name
    :param node: node about to be executed
    :return: resolved positional and named arguments
    """
    return (
        _pick_result(torch_results, node.args),
        _pick_result(torch_results, node.kwargs),
    )
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
def run_aligned(
    ep: torch.export.ExportedProgram,
    onx: Union[onnx.ModelProto, onnx.FunctionProto],
    args: Tuple[torch.Tensor, ...],
    check_conversion_cls: Union[Dict[str, Any], type],
    kwargs: Optional[Dict[str, Any]] = None,
    verbose: int = 0,
) -> Iterator[Tuple[Any, ...]]:
    """
    Runs in parallel both the exported program
    and the onnx proto and looks for discrepancies.
    The function matches results by name, so it assumes
    the exported program and the onnx model have the same names
    for equivalent results.

    :param ep: exported program
    :param onx: model or function proto
    :param args: input args
    :param check_conversion_cls: defines the runtime to use for this task,
        either a runtime class or a dictionary
        ``dict(cls=<runtime class>, atol=..., rtol=...)``
    :param kwargs: input kwargs (must be empty, not implemented yet)
    :param verbose: verbosity level
    :return: yields one tuple ``(ep_node_index, onnx_node_index, onnx_name,
        torch_name, diff)`` for every result common to both models

    Example:

    .. runpython::
        :showcode:
        :warningout: UserWarning

        import pprint
        import pandas
        import torch
        from onnx_diagnostic.reference import (
            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned


        class Model(torch.nn.Module):
            def forward(self, x):
                ry = x.abs()
                rz = ry.exp()
                rw = rz + 1
                ru = rw.log() + rw
                return ru


        def post_process(obs):
            dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
            dobs["err_abs"] = obs[-1]["abs"]
            dobs["err_rel"] = obs[-1]["rel"]
            return dobs


        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        )
        onx = torch.onnx.export(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        ).model_proto
        results = list(
            map(
                post_process,
                run_aligned(
                    ep,
                    onx,
                    (x,),
                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
                    verbose=1,
                ),
            ),
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
        print(df)
    """
    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
    # Unpack the runtime class and the tolerances. A bare class means
    # "report only": atol/rtol stay None and no discrepancy ever raises.
    cls, atol, rtol = (
        (
            check_conversion_cls["cls"],
            check_conversion_cls["atol"],
            check_conversion_cls["rtol"],
        )
        if isinstance(check_conversion_cls, dict)
        else (check_conversion_cls, None, None)
    )

    # retrieve the positions
    # Map every result name to the index of the fx node producing it.
    positions: Dict[str, Any] = {}
    for i, node in enumerate(ep.graph.nodes):
        if isinstance(node.name, str):
            positions[node.name] = dict(fx=i)
        else:
            # NOTE(review): assumes node.name can be an iterable of names here
            # — confirm against torch.fx semantics.
            for n in node.name:
                positions[n] = dict(fx=i)

    # Add, for every result name, the index of the onnx node producing it.
    for i, node in enumerate(onx.graph.node):
        for n in node.output:
            if n in positions:
                positions[n]["onnx"] = i
            else:
                positions[n] = dict(onnx=i)

    # Seed the onnx side with the initializers (as numpy arrays).
    onnx_results: Dict[str, Any] = {}
    for init in onx.graph.initializer:  # type: ignore
        positions[init.name] = -1
        onnx_results[init.name] = to_array_extended(init)
        # Also register the initializer under the "p_<name>" alias —
        # presumably the name torch.export gives lifted parameters; confirm.
        param_name = f"p_{init.name.replace('.', '_')}"
        if param_name == init.name:
            continue
        assert param_name not in onnx_results, (
            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
            f"and onnx_results has {sorted(onnx_results)}"
        )
        onnx_results[param_name] = onnx_results[init.name]

    # Mirror the constants on the torch side as tensors.
    # NOTE(review): names starting with "init" are skipped — prefix-based
    # filter, presumably anonymous initializers; confirm the convention.
    torch_results: Dict[str, Any] = {
        k: torch.from_numpy(v.copy())
        for k, v in onnx_results.items()
        if not k.startswith("init")
    }
    last_position = 0
    # Collect the names of the exported program outputs (last "output" node wins).
    torch_output_names = None
    for node in ep.graph.nodes:
        if node.op == "output":
            torch_output_names = [n.name for n in node.args[0]]
    onnx_outputs_names = [o.name for o in onx.graph.output]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    # Final outputs are matched positionally, intermediate results by name.
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

    if verbose:
        for k, v in torch_results.items():
            print(
                f"[run_aligned] +torch-cst: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )
        for k, v in onnx_results.items():
            print(
                f"[run_aligned] +onnx-init: {k}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    # Feed the onnx inputs, matched positionally with args.
    for inp, v in zip(onx.graph.input, args):
        onnx_results[inp.name] = to_numpy(v)
        if verbose:
            print(
                f"[run_aligned] +onnx-input: {inp.name}: "
                f"{string_type(v, with_shape=True, with_min_max=True)}"
            )

    # Main walk: execute the fx nodes in order; after each one, run the onnx
    # nodes up to the last position producing one of the fx node's outputs.
    # NOTE(review): `args` and `kwargs` (the function parameters) are rebound
    # below, and `i`/`node` are shadowed by inner loops — intentional but
    # fragile; the yielded `i` is always the current fx node index.
    for i, node in enumerate(ep.graph.nodes):
        if verbose:
            if node.op == "call_function":
                print(
                    f"[run_aligned] run ep.graph.nodes[{i}]: "
                    f"{node.op}[{node.target}] -> {node.name!r}"
                )
            else:
                print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")

        if node.op == "placeholder":
            # Placeholders take their value from the already-fed onnx inputs.
            if node.name in onnx_results:
                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
                if verbose:
                    t = torch_results[node.name]
                    print(
                        f"[run_aligned] +torch {node.name}="
                        f"{string_type(t, with_shape=True, with_min_max=True)}"
                    )
                continue
            raise AssertionError(
                f"unable to process node {node.op} -> {node.name!r} "
                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
                f"onx.graph.input={[i.name for i in onx.graph.input]}"
            )

        # Execute the fx node with inputs resolved from previous results.
        outputs = [node.name] if isinstance(node.name, str) else list(node.name)
        args, kwargs = prepare_args_kwargs(torch_results, node)
        new_outputs = run_fx_node(node, args, kwargs)
        if isinstance(new_outputs, (torch.Tensor, int, float, list)):
            new_outputs = (new_outputs,)

        if new_outputs is None:
            # Probably an assert.
            continue

        for k, v in zip(outputs, new_outputs):
            torch_results[k] = v
        if verbose:
            for k, v in zip(outputs, new_outputs):
                print(
                    f"[run_aligned] +torch {k}="
                    f"{string_type(v, with_shape=True, with_min_max=True)}"
                )

        # How far can the onnx side advance? Up to the last onnx node
        # producing one of the names this fx node just produced.
        max_pos = -2
        for n in outputs:
            if n in positions and "onnx" in positions[n]:
                max_pos = max(max_pos, positions[n]["onnx"])
        if max_pos == -2:
            # we skip.
            continue

        # Run the pending onnx nodes and compare every named result that
        # also exists on the torch side.
        for i_onnx in range(last_position, max_pos + 1):
            node = onx.graph.node[i_onnx]
            if verbose:
                print(
                    f"[run_aligned] run onx.graph.node[{i_onnx}]: "
                    f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
                )
            ref = cls(node)
            feeds = {k: onnx_results[k] for k in node.input}
            res = ref.run(None, feeds)
            for o, r in zip(node.output, res):
                onnx_results[o] = r
                if verbose:
                    print(
                        f"[run_aligned] +onnx {o}="
                        f"{string_type(r, with_shape=True, with_min_max=True)}"
                    )

                to = mapping_onnx_to_torch.get(o, o)
                if to in torch_results:
                    d = max_diff(torch_results[to], r)
                    if verbose:
                        if o == to:
                            print(f"[run_aligned] =common results {to}: {string_diff(d)}")
                        else:
                            print(f"[run_aligned] =common results {to}/{o}: {string_diff(d)}")
                    # Fail only when both tolerances are set and exceeded.
                    if not (
                        atol is None
                        or rtol is None
                        or (d["abs"] <= atol and d["rel"] <= rtol)
                    ):
                        skw = dict(with_shape=True, with_min_max=True)
                        raise ValueError(
                            f"discrepancies detected for results [{to}/{o}]: "
                            f"{string_diff(d)}"
                            f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                            f"\n-- onnx_results: {string_type(r, **skw)}"
                            f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                        )
                    yield (i, i_onnx, o, to, d)

        last_position = max_pos + 1

    # complete the execution of the onnx graph
    # (onnx nodes past the last fx-aligned position, e.g. output reshaping).
    for i_onnx in range(last_position, len(onx.graph.node)):
        node = onx.graph.node[i_onnx]
        if verbose:
            print(
                f"[run_aligned] run onx.graph.node[{i_onnx}]: "
                f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
            )
        ref = cls(node)
        feeds = {k: onnx_results[k] for k in node.input}
        res = ref.run(None, feeds)
        for o, r in zip(node.output, res):
            onnx_results[o] = r
            if verbose:
                print(
                    f"[run_aligned] +onnx {o}="
                    f"{string_type(r, with_shape=True, with_min_max=True)}"
                )

            to = mapping_onnx_to_torch.get(o, o)
            if to in torch_results:
                # NOTE(review): `i` here is whatever the main loop left behind
                # (the last fx node index) — the yielded tuple keeps that value.
                d = max_diff(torch_results[to], r)
                if verbose:
                    if o == to:
                        print(f"[run_aligned] =common results* {to}: {string_diff(d)}")
                    else:
                        print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}")
                if not (
                    atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
                ):
                    skw = dict(with_shape=True, with_min_max=True)
                    raise ValueError(
                        f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}"
                        f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
                        f"\n-- onnx_results: {string_type(r, **skw)}"
                        f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
                    )
                yield (i, i_onnx, o, to, d)
|
|
@@ -0,0 +1,213 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: onnx-diagnostic
|
|
3
|
+
Version: 0.8.0
|
|
4
|
+
Summary: Tools to help converting pytorch models into ONNX.
|
|
5
|
+
Home-page: https://github.com/sdpython/onnx-diagnostic
|
|
6
|
+
Author: Xavier Dupré
|
|
7
|
+
Author-email: Xavier Dupré <xavier.dupre@gmail.com>
|
|
8
|
+
License: MIT
|
|
9
|
+
Project-URL: Homepage, https://sdpython.github.io/doc/onnx-diagnostic/dev/
|
|
10
|
+
Project-URL: Repository, https://github.com/sdpython/onnx-diagnostic/
|
|
11
|
+
Requires-Python: >=3.9
|
|
12
|
+
Description-Content-Type: text/x-rst
|
|
13
|
+
License-File: LICENSE.txt
|
|
14
|
+
Dynamic: author
|
|
15
|
+
Dynamic: home-page
|
|
16
|
+
Dynamic: license-file
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
.. image:: https://github.com/sdpython/onnx-diagnostic/raw/main/_doc/_static/logo.png
|
|
20
|
+
:width: 120
|
|
21
|
+
|
|
22
|
+
onnx-diagnostic: investigate onnx models
|
|
23
|
+
========================================
|
|
24
|
+
|
|
25
|
+
.. image:: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml/badge.svg
|
|
26
|
+
:target: https://github.com/sdpython/onnx-diagnostic/actions/workflows/documentation.yml
|
|
27
|
+
|
|
28
|
+
.. image:: https://badge.fury.io/py/onnx-diagnostic.svg
|
|
29
|
+
:target: http://badge.fury.io/py/onnx-diagnostic
|
|
30
|
+
|
|
31
|
+
.. image:: https://img.shields.io/badge/license-MIT-blue.svg
|
|
32
|
+
:alt: MIT License
|
|
33
|
+
:target: https://opensource.org/license/MIT/
|
|
34
|
+
|
|
35
|
+
.. image:: https://img.shields.io/github/repo-size/sdpython/onnx-diagnostic
|
|
36
|
+
:target: https://github.com/sdpython/onnx-diagnostic/
|
|
37
|
+
:alt: size
|
|
38
|
+
|
|
39
|
+
.. image:: https://img.shields.io/badge/code%20style-black-000000.svg
|
|
40
|
+
:target: https://github.com/psf/black
|
|
41
|
+
|
|
42
|
+
.. image:: https://codecov.io/gh/sdpython/onnx-diagnostic/graph/badge.svg?token=91T5ZVIP96
|
|
43
|
+
:target: https://codecov.io/gh/sdpython/onnx-diagnostic
|
|
44
|
+
|
|
45
|
+
The main feature is about `patches <https://github.com/sdpython/onnx-diagnostic/tree/main/onnx_diagnostic/torch_export_patches>`_:
|
|
46
|
+
it helps exporting **pytorch models into ONNX**, mostly designed for LLMs using dynamic caches.
|
|
47
|
+
Patches can be enabled as follows:
|
|
48
|
+
|
|
49
|
+
.. code-block:: python
|
|
50
|
+
|
|
51
|
+
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
52
|
+
|
|
53
|
+
with torch_export_patches(patch_transformers=True) as f:
|
|
54
|
+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
55
|
+
# ...
|
|
56
|
+
|
|
57
|
+
Dynamic shapes are difficult to guess for caches, one function
|
|
58
|
+
returns a structure defining all dimensions as dynamic.
|
|
59
|
+
You need then to remove those which are not dynamic in your model.
|
|
60
|
+
|
|
61
|
+
.. code-block:: python
|
|
62
|
+
|
|
63
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
64
|
+
|
|
65
|
+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
|
|
66
|
+
|
|
67
|
+
It also implements tools to investigate and validate exported models (ExportedProgram, ONNXProgram, ...).
|
|
68
|
+
See `documentation of onnx-diagnostic <https://sdpython.github.io/doc/onnx-diagnostic/dev/>`_ and
|
|
69
|
+
`torch_export_patches <https://sdpython.github.io/doc/onnx-diagnostic/dev/api/torch_export_patches/index.html#onnx_diagnostic.torch_export_patches.torch_export_patches>`_.
|
|
70
|
+
|
|
71
|
+
Getting started
|
|
72
|
+
+++++++++++++++
|
|
73
|
+
|
|
74
|
+
::
|
|
75
|
+
|
|
76
|
+
git clone https://github.com/sdpython/onnx-diagnostic.git
|
|
77
|
+
cd onnx-diagnostic
|
|
78
|
+
pip install -e . -v
|
|
79
|
+
|
|
80
|
+
or
|
|
81
|
+
|
|
82
|
+
::
|
|
83
|
+
|
|
84
|
+
pip install onnx-diagnostic
|
|
85
|
+
|
|
86
|
+
Enlightening Examples
|
|
87
|
+
+++++++++++++++++++++
|
|
88
|
+
|
|
89
|
+
**Where to start to export a model**
|
|
90
|
+
|
|
91
|
+
* `Export microsoft/phi-2
|
|
92
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_phi2.html>`_
|
|
93
|
+
|
|
94
|
+
**Torch Export**
|
|
95
|
+
|
|
96
|
+
* `Use DYNAMIC or AUTO when exporting if dynamic shapes has constraints
|
|
97
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_shapes_auto.html>`_
|
|
98
|
+
* `Find and fix an export issue due to dynamic shapes
|
|
99
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_locate_issue.html>`_
|
|
100
|
+
* `Export with DynamicCache and guessed dynamic shapes
|
|
101
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_with_dynamic_cache.html>`_
|
|
102
|
+
* `Steel method forward to guess the dynamic shapes (with Tiny-LLM)
|
|
103
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm.html>`_
|
|
104
|
+
* `Export Tiny-LLM with patches
|
|
105
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_export_tiny_llm_patched.html>`_
|
|
106
|
+
|
|
107
|
+
**Investigate ONNX models**
|
|
108
|
+
|
|
109
|
+
* `Find where a model is failing by running submodels
|
|
110
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_model_extract.html>`_
|
|
111
|
+
* `Intermediate results with (ONNX) ReferenceEvaluator
|
|
112
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_reference_evaluator.html>`_
|
|
113
|
+
* `Intermediate results with onnxruntime
|
|
114
|
+
<https://sdpython.github.io/doc/onnx-diagnostic/dev/auto_examples/plot_failing_onnxruntime_evaluator.html>`_
|
|
115
|
+
|
|
116
|
+
Snapshot of useful tools
|
|
117
|
+
+++++++++++++++++++++++++
|
|
118
|
+
|
|
119
|
+
**torch_export_patches**
|
|
120
|
+
|
|
121
|
+
.. code-block:: python
|
|
122
|
+
|
|
123
|
+
from onnx_diagnostic.torch_export_patches import torch_export_patches
|
|
124
|
+
|
|
125
|
+
with torch_export_patches(patch_transformers=True) as f:
|
|
126
|
+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
127
|
+
# ...
|
|
128
|
+
|
|
129
|
+
**all_dynamic_shapes_from_inputs**
|
|
130
|
+
|
|
131
|
+
.. code-block:: python
|
|
132
|
+
|
|
133
|
+
from onnx_diagnostic.export.shape_helper import all_dynamic_shapes_from_inputs
|
|
134
|
+
|
|
135
|
+
dynamic_shapes = all_dynamic_shapes_from_inputs(cache)
|
|
136
|
+
|
|
137
|
+
**torch_export_rewrite**
|
|
138
|
+
|
|
139
|
+
.. code-block:: python
|
|
140
|
+
|
|
141
|
+
from onnx_diagnostic.torch_export_patches import torch_export_rewrite
|
|
142
|
+
|
|
143
|
+
with torch_export_rewrite(rewrite=[Model.forward]) as f:
|
|
144
|
+
ep = torch.export.export(model, args, kwargs=kwargs, dynamic_shapes=dynamic_shapes)
|
|
145
|
+
# ...
|
|
146
|
+
|
|
147
|
+
**string_type**
|
|
148
|
+
|
|
149
|
+
.. code-block:: python
|
|
150
|
+
|
|
151
|
+
import torch
|
|
152
|
+
from onnx_diagnostic.helpers import string_type
|
|
153
|
+
|
|
154
|
+
inputs = (
|
|
155
|
+
torch.rand((3, 4), dtype=torch.float16),
|
|
156
|
+
[torch.rand((5, 6), dtype=torch.float16), torch.rand((5, 6, 7), dtype=torch.float16)],
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# with shapes
|
|
160
|
+
print(string_type(inputs, with_shape=True))
|
|
161
|
+
|
|
162
|
+
::
|
|
163
|
+
|
|
164
|
+
>>> (T10s3x4,#2[T10s5x6,T10s5x6x7])
|
|
165
|
+
|
|
166
|
+
**onnx_dtype_name**
|
|
167
|
+
|
|
168
|
+
.. code-block:: python
|
|
169
|
+
|
|
170
|
+
import onnx
|
|
171
|
+
from onnx_diagnostic.helpers.onnx_helper import onnx_dtype_name
|
|
172
|
+
|
|
173
|
+
itype = onnx.TensorProto.BFLOAT16
|
|
174
|
+
print(onnx_dtype_name(itype))
|
|
175
|
+
print(onnx_dtype_name(7))
|
|
176
|
+
|
|
177
|
+
::
|
|
178
|
+
|
|
179
|
+
>>> BFLOAT16
|
|
180
|
+
>>> INT64
|
|
181
|
+
|
|
182
|
+
**max_diff**
|
|
183
|
+
|
|
184
|
+
.. code-block:: python
|
|
185
|
+
|
|
186
|
+
import torch
|
|
187
|
+
from onnx_diagnostic.helpers import max_diff
|
|
188
|
+
|
|
189
|
+
print(
|
|
190
|
+
max_diff(
|
|
191
|
+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
|
|
192
|
+
(torch.Tensor([1, 2]), (torch.Tensor([1, 2]),)),
|
|
193
|
+
)
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
::
|
|
197
|
+
|
|
198
|
+
>>> {"abs": 0.0, "rel": 0.0, "sum": 0.0, "n": 4.0, "dnan": 0.0}
|
|
199
|
+
|
|
200
|
+
**guess_dynamic_shapes**
|
|
201
|
+
|
|
202
|
+
.. code-block:: python
|
|
203
|
+
|
|
204
|
+
inputs = [
|
|
205
|
+
(torch.randn((5, 6)), torch.randn((1, 6))),
|
|
206
|
+
(torch.randn((7, 8)), torch.randn((1, 8))),
|
|
207
|
+
]
|
|
208
|
+
ds = ModelInputs(model, inputs).guess_dynamic_shapes(auto="dim")
|
|
209
|
+
print(ds)
|
|
210
|
+
|
|
211
|
+
::
|
|
212
|
+
|
|
213
|
+
>>> (({0: 'dim_0I0', 1: 'dim_0I1'}, {1: 'dim_1I1'}), {})
|