onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnx_diagnostic/__init__.py +1 -1
- onnx_diagnostic/_command_lines_parser.py +412 -12
- onnx_diagnostic/export/api.py +111 -8
- onnx_diagnostic/export/control_flow.py +48 -345
- onnx_diagnostic/export/control_flow_onnx.py +528 -0
- onnx_diagnostic/export/control_flow_research.py +12 -7
- onnx_diagnostic/export/onnx_plug.py +531 -0
- onnx_diagnostic/ext_test_case.py +163 -48
- onnx_diagnostic/helpers/cache_helper.py +1 -1
- onnx_diagnostic/helpers/dot_helper.py +222 -0
- onnx_diagnostic/helpers/helper.py +108 -37
- onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
- onnx_diagnostic/helpers/model_builder_helper.py +27 -0
- onnx_diagnostic/helpers/onnx_helper.py +531 -6
- onnx_diagnostic/helpers/ort_session.py +45 -19
- onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
- onnx_diagnostic/helpers/torch_helper.py +131 -8
- onnx_diagnostic/reference/ort_evaluator.py +228 -46
- onnx_diagnostic/tasks/feature_extraction.py +15 -14
- onnx_diagnostic/tasks/summarization.py +72 -137
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
- onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
- onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
- onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
- onnx_diagnostic/torch_models/code_sample.py +2 -1
- onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
- onnx_diagnostic/torch_models/validate.py +64 -2
- onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
- onnx_diagnostic/torch_onnx/sbs.py +969 -312
- onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
- {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
|
@@ -1,158 +1,586 @@
|
|
|
1
|
-
|
|
1
|
+
import inspect
|
|
2
|
+
import time
|
|
3
|
+
from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Sequence, Tuple, Union
|
|
2
4
|
import onnx
|
|
5
|
+
import onnx.helper as oh
|
|
6
|
+
import numpy as np
|
|
3
7
|
import torch
|
|
4
|
-
from ..helpers import string_type, string_diff, max_diff
|
|
5
|
-
from ..helpers.onnx_helper import
|
|
6
|
-
from ..helpers.torch_helper import
|
|
8
|
+
from ..helpers import string_type, string_diff, max_diff, flatten_object
|
|
9
|
+
from ..helpers.onnx_helper import pretty_onnx
|
|
10
|
+
from ..helpers.torch_helper import (
|
|
11
|
+
to_numpy,
|
|
12
|
+
from_numpy,
|
|
13
|
+
to_tensor,
|
|
14
|
+
torch_dtype_to_onnx_dtype,
|
|
15
|
+
)
|
|
16
|
+
from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
|
|
17
|
+
from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
|
|
18
|
+
from .sbs_dataclasses import (
|
|
19
|
+
ReplayConfiguration,
|
|
20
|
+
RunAlignedRecord,
|
|
21
|
+
StatusRunAligned,
|
|
22
|
+
make_torch_inputs,
|
|
23
|
+
)
|
|
7
24
|
|
|
8
25
|
|
|
9
|
-
def
|
|
10
|
-
|
|
11
|
-
|
|
12
|
-
|
|
13
|
-
|
|
26
|
+
def _check_tensor_(use_tensor, name, obj, flip_type=False):
    """
    Validates, and optionally converts, a value exchanged with the onnx runtime.

    :param use_tensor: True if torch tensors are expected, False for numpy arrays
    :param name: name of the result, only used in error messages
    :param obj: value to check
    :param flip_type: if True, a numpy array is first converted into a torch tensor
        when ``use_tensor`` is True, a torch tensor into a numpy array otherwise
    :return: the value, converted if ``flip_type`` is True
    """
    if flip_type and use_tensor and isinstance(obj, np.ndarray):
        obj = from_numpy(obj)
    elif flip_type and not use_tensor and isinstance(obj, torch.Tensor):
        obj = to_numpy(obj)

    # OnnxList is accepted in both modes.
    if use_tensor:
        assert isinstance(obj, (torch.Tensor, OnnxList)), (
            f"Unexpected type {type(obj)} for {name!r}. "
            f"use_tensor is True so torch.Tensor is expected."
        )
    else:
        assert isinstance(obj, (np.ndarray, OnnxList)), (
            f"Unexpected type {type(obj)} for {name!r}. "
            f"use_tensor is False so np.array is expected."
        )
    return obj
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto:
    """
    Wraps an initializer into a ``Constant`` node whose single output
    has the same name as the initializer.
    """
    outputs = [proto.name]
    return oh.make_node("Constant", inputs=[], outputs=outputs, value=proto)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _loop_cmp(
    mapping_onnx_to_torch: Dict[str, str],
    torch_results: Dict[str, torch.Tensor],
    onnx_results: Dict[str, Any],
    onnx_name: str,
    onnx_result: torch.Tensor,
    second_onnx_result: Optional[torch.Tensor],
    verbose: int,
    atol: Optional[float],
    rtol: Optional[float],
    i_torch: int,
    i_onnx: int,
    str_kws: Dict[str, bool],
    exc: bool,
    use_tensor: bool,
) -> Optional[RunAlignedRecord]:
    """
    Stores one onnx result and compares it with its torch counterpart if there is one.

    :param mapping_onnx_to_torch: onnx name -> torch name, updated with the pair
        ``onnx_name -> torch name`` when a comparison happens
    :param torch_results: torch results computed so far
    :param onnx_results: onnx results computed so far, ``onnx_result`` is stored in it
    :param onnx_name: name of the onnx result
    :param onnx_result: value of the onnx result
    :param second_onnx_result: optional value of the same result obtained by running
        the onnx node on torch inputs, compared as well if not None
    :param verbose: verbosity
    :param atol: absolute tolerance, no check if None
    :param rtol: relative tolerance, no check if None
    :param i_torch: index of the current node in the exported program
    :param i_onnx: index of the current onnx node
    :param str_kws: keyword arguments forwarded to ``string_type``
    :param exc: raise an exception if the discrepancies exceed the tolerances,
        print them otherwise
    :param use_tensor: onnx results are torch tensors (True) or numpy arrays (False)
    :return: a record if a torch result was found, None otherwise
    :raises ValueError: if ``exc`` is True and the discrepancies exceed the tolerances
    """
    onnx_results[onnx_name] = _check_tensor_(use_tensor, onnx_name, onnx_result)
    if verbose > 1:
        print(f"[run_aligned-nx] +res: {onnx_name}={string_type(onnx_result, **str_kws)}")

    to = mapping_onnx_to_torch.get(onnx_name, onnx_name)
    if to not in torch_results:
        return None

    torch_result = torch_results[to]
    d = max_diff(torch_result, onnx_result, hist=[0.1, 0.01])
    if verbose > 1:
        if onnx_name == to:
            print(f"[run_aligned-==] cmp {to}: {string_diff(d)}")
        else:
            print(f"[run_aligned-~~] cmd {to}/{onnx_name}: {string_diff(d)}")
    if not (atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)):
        if exc:
            # The torch value comes from torch_results: indexing onnx_result
            # (a tensor) with a string would raise while building the message.
            raise ValueError(
                f"discrepancies detected for results [{to}/{onnx_name}]: "
                f"{string_diff(d)}"
                f"\n-- torch_result: {string_type(torch_result, **str_kws)}"
                f"\n-- onnx_result: {string_type(onnx_result, **str_kws)}"
                f"\n-- torch\n{torch_result}"
            )
        print(f"[run_align-dx] discrepancies {string_diff(d)} - [{to}/{onnx_name}]")

    r = RunAlignedRecord(
        ep_id_node=i_torch,
        onnx_id_node=i_onnx,
        ep_name=to,
        onnx_name=onnx_name,
        ep_shape_type=string_type(torch_result, **str_kws),
        onnx_shape_type=string_type(onnx_result, **str_kws),
    )
    r.set_diff(d)
    if second_onnx_result is not None:
        d2 = max_diff(torch_result, second_onnx_result, hist=[0.1, 0.01])
        r.set_diff2(d2)
    mapping_onnx_to_torch[onnx_name] = to
    return r
|
|
32
106
|
|
|
33
107
|
|
|
34
|
-
def
|
|
35
|
-
|
|
36
|
-
|
|
108
|
+
def _duplicated_values(d):
|
|
109
|
+
rev = {}
|
|
110
|
+
for k, v in d.items():
|
|
111
|
+
if v in rev:
|
|
112
|
+
rev[v].append(k)
|
|
113
|
+
else:
|
|
114
|
+
rev[v] = [k]
|
|
115
|
+
res = {k: v for k, v in rev.items() if len(v) > 1}
|
|
116
|
+
final = set()
|
|
117
|
+
for v in res.values():
|
|
118
|
+
final |= set(v)
|
|
119
|
+
return final
|
|
37
120
|
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
121
|
+
|
|
122
|
+
def _validation_nn_functional(
    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
) -> Optional[str]:
    """
    Compares an onnx node with the torch function it is equivalent to, if known.

    Only ``Gemm(A, B, C)`` computing ``A @ B.T + C`` (``alpha=1``, ``beta=1``,
    ``transA=0``, ``transB=1``, attributes being explicit or left to their defaults)
    is handled: it is evaluated with :func:`torch.nn.functional.linear`
    on ``new_feeds`` and compared with ``expected[0]``.

    :param node: onnx node
    :param new_feeds: inputs of the node, indexed by input name
    :param expected: expected outputs
    :return: a short comment describing the discrepancies, None if the node
        is not handled or its inputs are not all available
    """
    if node.op_type == "Gemm" and len(node.input) == 3 and expected:
        # Starts from the ONNX defaults so that attributes explicitly set
        # to their default value are recognized as well.
        atts = {"alpha": 1.0, "beta": 1.0, "transA": 0, "transB": 0}
        for att in node.attribute:
            if att.name in ("alpha", "beta"):
                atts[att.name] = att.f
            elif att.name in ("transA", "transB"):
                atts[att.name] = att.i
        # An empty name means the optional input C is missing.
        inputs_available = all(name and name in new_feeds for name in node.input)
        if inputs_available and atts == {
            "alpha": 1.0,
            "beta": 1.0,
            "transA": 0,
            "transB": 1,
        }:
            res = torch.nn.functional.linear(*[new_feeds[i] for i in node.input])
            diff = max_diff(res, expected[0])
            return f"function.linear:{string_diff(diff)}"
    return None
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _loop_onnx_node(
    onx: onnx.ModelProto,
    ep_graph_nodes: List[torch.fx.Node],
    onnx_results: Dict[str, Any],
    mapping_onnx_to_torch: Dict[str, str],
    torch_results: Dict[str, torch.Tensor],
    ep_durations,
    use_tensor: bool,
    i_torch: int,
    i_onnx: int,
    name_to_ep_node: Dict[str, int],
    run_cls_kwargs: Dict[str, Any],
    str_kws: Dict[str, bool],
    status: StatusRunAligned,
    already_run_onnx: Set[int],
    torch_names_to_onnx_names: Dict[str, str],
    verbose: int,
    exc: bool,
    atol: float,
    rtol: float,
    reset_names: Set[str],
    replay_configuration: Optional[ReplayConfiguration],
    has_cuda: bool,
    run_cls: type,
    loop: Any,
    run_onnx_with_torch_inputs: bool,
) -> Iterator[Optional[RunAlignedRecord]]:
    """
    Runs one onnx node and compares its outputs with the torch results.

    The node ``onx.graph.node[i_onnx]`` is executed with ``run_cls(node, **run_cls_kwargs)``
    on copies of its inputs taken from ``onnx_results``. Every output is stored in
    ``onnx_results`` and compared with the torch result it maps to (see ``_loop_cmp``).
    If ``run_onnx_with_torch_inputs`` is True and all the inputs have a torch
    counterpart, the node is run a second time on the torch results.

    Yields a record for every output having a torch counterpart. It also yields an
    extra record with ``onnx_op_type="reset"`` when that output's torch name belongs
    to ``reset_names``. If the node was already run, it yields a single None and
    nothing else.
    """
    if i_onnx in already_run_onnx:
        # Running the node again would duplicate records, replays and resets.
        yield None
        return

    node = onx.graph.node[i_onnx]
    if verbose > 1:
        print(
            f"[run_aligned] run onx.graph.node[{i_onnx}]: "
            f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
        )
    elif verbose == 1:
        loop.set_description(
            f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
            f"{status.to_str()}"
        )
        loop.update(min(1, 1 + i_torch + i_onnx))

    ref = run_cls(node, **run_cls_kwargs)
    # We need to clone because the runtime maybe using dlpack to create OrtValue
    hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node)
    all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input
    feeds = (
        {k: onnx_results[k].clone() for k in all_inputs if k}
        if use_tensor
        else {k: onnx_results[k].copy() for k in all_inputs if k}
    )
    assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
    if verbose > 1:
        print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}")
    begin = time.perf_counter()
    try:
        res = ref.run(None, feeds)  # type: ignore[attr-defined]
    except Exception as e:
        raise RuntimeError(
            f"Unable to run node {node.op_type}, domain={node.domain} "
            f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
        ) from e
    duration = time.perf_counter() - begin
    if verbose > 1:
        print(f"[run_aligned] res={string_type(res, **str_kws)}")
    assert (
        not has_cuda
        or not any(t is not None and t.is_cuda for t in feeds.values())
        or any(t is not None and t.is_cuda for t in res)
        or node.op_type in {"Shape", "Size"}  # on CPU no matter what
        or node.op_type
        in {
            "Add",
            "Concat",
            "Div",
            "Gather",
            "Mul",
            "Range",
            "Squeeze",
            "Sub",
            "Unsqueeze",
        }  # not sure, could be about shapes
    ), (
        f"One input is on cuda but there is no float output on cuda, "
        f"feeds={string_type(feeds, with_device=True, with_shape=True)}, "
        f"res={string_type(res, with_device=True, with_shape=True)}, "
        f"node is {pretty_onnx(node)}"
    )

    comment = None
    cross = None
    if run_onnx_with_torch_inputs:
        # Let's run the operator with torch results if they are available
        new_feeds, removed = make_torch_inputs(
            node.input,
            {
                **{v: k for k, v in torch_names_to_onnx_names.items()},
                **mapping_onnx_to_torch,
            },
            onnx_results,
            torch_results,
            submodel=None,
        )
        if not removed:
            if verbose > 1:
                print(
                    f"[run_aligned] feeds for second run="
                    f"{string_type(new_feeds, **str_kws)}"
                )
            cross = ref.run(None, new_feeds)
            if verbose > 1:
                print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
            # Gemm = torch.nn.functional.linear, in that case, we just run it as well
            to = mapping_onnx_to_torch.get(node.output[0], node.output[0])
            if to in torch_results:
                comment = _validation_nn_functional(node, new_feeds, [torch_results[to]])
        elif verbose > 1:
            print(f"[run_aligned] second run not possible because of missing {removed}")

    if cross is None:
        cross = [None for _ in res]

    list_node_output = list(node.output)
    node_output = [o for o in list_node_output if o]
    for o, r, r2 in zip(node_output, res, cross):
        if r is None or not o:
            continue
        tmp = _loop_cmp(
            mapping_onnx_to_torch,
            torch_results,
            onnx_results,
            o,
            r,
            r2,
            verbose,
            atol,
            rtol,
            i_torch,
            i_onnx,
            str_kws,
            exc,
            use_tensor,
        )
        if tmp is None:
            # no torch result to compare with
            continue

        if tmp.ep_name in name_to_ep_node:
            tmp.ep_id_node = name_to_ep_node[tmp.ep_name]
            tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target)
            tmp.ep_time_run = ep_durations[tmp.ep_id_node]
        else:
            tmp.ep_id_node = None
            tmp.ep_target = None
            tmp.ep_name = None
        tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type
        tmp.onnx_id_output = list_node_output.index(o)
        tmp.onnx_time_run = duration
        status.yielded_nodes += 1
        if tmp.err_abs is not None:
            status.update(tmp.err_abs)
        tmp.comment = comment
        yield tmp

        # do we need to dump pieces if graph the user can replay?
        if replay_configuration:
            if replay_configuration.select(
                name=tmp.onnx_name, op_type=tmp.onnx_op_type, err_abs=tmp.err_abs
            ):
                replay_configuration.dump(
                    name=tmp.onnx_name,
                    onnx_id_node=tmp.onnx_id_node,
                    model=onx,
                    onnx_results=onnx_results,
                    torch_results=torch_results,
                    onnx_name_to_ep_name={
                        **{v: k for k, v in torch_names_to_onnx_names.items()},
                        **mapping_onnx_to_torch,
                    },
                    verbose=max(verbose - 1, 0),
                )
                status.last_replay = tmp.onnx_name

        # reset_names: replaces onnx_results by torch_results to see
        # if that fixes the discrepancies problem
        if reset_names and tmp.ep_name in reset_names:
            assert (
                tmp.ep_name in torch_results
            ), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
            assert (
                tmp.onnx_name in onnx_results
            ), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
            # The torch value must take the type the onnx runtime uses
            # (numpy arrays when use_tensor is False), otherwise _loop_cmp
            # rejects it.
            reset_value = _check_tensor_(
                use_tensor, tmp.onnx_name, torch_results[tmp.ep_name], flip_type=True
            )
            onnx_results[tmp.onnx_name] = reset_value
            tmp = _loop_cmp(
                mapping_onnx_to_torch,
                torch_results,
                onnx_results,
                o,
                reset_value,
                None,
                verbose,
                atol,
                rtol,
                i_torch,
                i_onnx,
                str_kws,
                exc,
                use_tensor,
            )
            assert tmp is not None and tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}"
            tmp.onnx_op_type = "reset"
            tmp.onnx_id_output = list_node_output.index(o)
            status.yielded_nodes += 1
            yield tmp
    already_run_onnx.add(i_onnx)
|
|
94
353
|
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
:
|
|
98
|
-
:
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
|
|
354
|
+
|
|
355
|
+
def _preparation_with_fx_graph(
    ep_graph_nodes: List[torch.fx.Node],
    name_to_ep_node: Dict[str, int],
    torch_input_names: List[str],
    onx: onnx.ModelProto,
    torch_names_to_onnx_names: Dict[str, str],
    skip_mapping_torch_onnx,
    torch_results: Dict[str, torch.Tensor],
    placeholders: Dict[str, torch.Tensor],
    placeholders_to_state_dict: Dict[str, str],
    ep_state_dict: Dict[str, torch.Tensor],
    positions: Dict[str, Dict[str, int]],
) -> List[str]:
    """
    Walks through the nodes of the exported program before running anything.

    It fills the dictionaries and lists given as arguments:

    * ``positions[name]["fx"]`` and ``name_to_ep_node[name]`` receive the index
      of every node,
    * placeholders found in ``placeholders_to_state_dict`` are weights, their value
      is taken from ``ep_state_dict`` and stored in ``placeholders`` and ``torch_results``,
    * the other placeholders are inputs, mapped by position to the onnx inputs in
      ``torch_names_to_onnx_names`` and appended to ``torch_input_names``.

    :return: the names of the results returned by the output node
    """
    torch_output_names = None
    for i, node in enumerate(ep_graph_nodes):
        if isinstance(node.name, str):
            positions[node.name] = dict(fx=i)
        else:
            for n in node.name:
                positions[n] = dict(fx=i)
        if node.op == "placeholder":
            if node.name in placeholders_to_state_dict:
                # This is a weight.
                placeholders[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
                torch_results[node.name] = placeholders[node.name]
                assert isinstance(torch_results[node.name], torch.Tensor), (
                    f"torch_results[{node.name}] not a tensor but "
                    f"{type(torch_results[node.name])}"
                )
            else:
                # This is an input, mapped to the next unmapped onnx input.
                assert len(torch_input_names) < len(onx.graph.input), (
                    f"torch_input_names={torch_input_names!r}, "
                    f"onnx_input_names={[n.name for n in onx.graph.input]}, "
                    f"node.name={node.name!r} cannot be an input, "
                    f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}"
                )
                assert node.name not in skip_mapping_torch_onnx, (
                    f"{node.name!r} is ambiguous, cannot be mapped due to "
                    f"{skip_mapping_torch_onnx}"
                )
                torch_names_to_onnx_names[node.name] = onx.graph.input[
                    len(torch_input_names)
                ].name
                torch_input_names.append(node.name)
        elif node.op == "output":
            torch_output_names = [n.name for n in node.args[0]]
        assert isinstance(node.name, str), (
            f"Unexpected type {type(node.name)} for node={node} (target={node.target}), "
            f"args={node.args}"
        )
        name_to_ep_node[node.name] = i
    assert torch_output_names is not None, "No output node was found in the graph."
    return torch_output_names
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def _preparation_with_onnx_model(
    default_device,
    use_tensor: bool,
    onx: onnx.ModelProto,
    already_yielded: Dict[str, Any],
    str_kws: Dict[str, bool],
    positions: Dict[str, Dict[str, int]],
    torch_names_to_onnx_names: Dict[str, str],
    torch_output_names: List[str],
    torch_results: Dict[str, torch.Tensor],
    skip_mapping_torch_onnx: Set[str],
    verbose: int,
    args: Sequence[Any],
    kwargs: Dict[str, Any],
) -> Tuple[Dict[str, str], Dict[str, torch.Tensor], float, float, List[RunAlignedRecord]]:
    """
    Prepares the onnx side before the onnx nodes are run one by one.

    * fills ``positions[name]["onnx"]``: -1 for inputs and initializers,
      the index of the producing node otherwise,
    * maps the onnx outputs to the torch outputs by position,
    * stores the model inputs (``args`` then flattened ``kwargs``) in ``onnx_results``,
    * maps initializers to torch results, by name or through aliases derived
      from the initializer name (an alias claimed by several initializers is
      ambiguous and ignored), then stores them in ``onnx_results``.

    :return: ``(mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda,
        records_to_yield)``, the memory values being the number of bytes used by
        the initializers on CPU and CUDA, ``records_to_yield`` the records created
        for the initializers not mapped to a torch result
    """
    for inp in [*onx.graph.input, *onx.graph.initializer]:
        positions.setdefault(inp.name, {})["onnx"] = -1
    for i, node in enumerate(onx.graph.node):
        for n in node.output:
            positions.setdefault(n, {})["onnx"] = i

    onnx_outputs_names = [o.name for o in onx.graph.output]
    assert torch_output_names is not None and len(torch_output_names) == len(
        onnx_outputs_names
    ), (
        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
        f"onnx_outputs_names={onnx_outputs_names}"
    )
    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

    onnx_args = list(args) if args else []
    if kwargs:
        onnx_args.extend(flatten_object(kwargs, drop_keys=True))
    if verbose:
        print(f"[run_aligned] args: {string_type(args, **str_kws)}")
        print(f"[run_aligned] kwargs: {string_type(kwargs, **str_kws)}")
        print(f"[run_aligned] onnx: {string_type(onnx_args, **str_kws)}")
        print(f"[run_aligned] nx: walks through {len(onx.graph.input)} onnx inputs")
    onnx_results: Dict[str, Any] = {}
    for inp, v in zip(onx.graph.input, onnx_args):
        onnx_results[inp.name] = _check_tensor_(
            use_tensor, inp.name, v if use_tensor else to_numpy(v)
        )
        if verbose:
            print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")

    # alias for initializers
    skip_onnx_name: Set[str] = set()
    init_aliases: Dict[str, str] = {}
    for init in onx.graph.initializer:
        short_name = init.name.split("::")[0].split("--")[-1].replace(".", "_")
        aliases = {
            n
            for n in [f"p_{init.name.replace('.', '_')}", f"p_{short_name}", short_name]
            if n != init.name
        }
        # An alias already claimed by another initializer (or already known as
        # ambiguous) cannot be trusted: it is removed and remembered in
        # skip_onnx_name so that no later initializer claims it either.
        if any(a in init_aliases or a in skip_onnx_name for a in aliases):
            skip_onnx_name |= aliases | {init.name}
            for alias in aliases:
                init_aliases.pop(alias, None)
        else:
            for alias in aliases:
                init_aliases[alias] = init.name
    rev_init_aliases: Dict[str, Set[str]] = {}
    for alias, init_name in init_aliases.items():
        rev_init_aliases.setdefault(init_name, set()).add(alias)

    # initializers
    if verbose:
        print(f"[run_aligned] nx: handles {len(onx.graph.initializer)} initializers from onnx")
    memory_cpu = 0
    memory_cuda = 0
    records_to_yield = []
    for init in onx.graph.initializer:  # type: ignore
        t = None
        if init.name in torch_results:
            if init.name not in skip_mapping_torch_onnx:
                t = torch_results[init.name]
                torch_names_to_onnx_names[init.name] = init.name
        elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
            candidates = [
                k
                for k in rev_init_aliases[init.name]
                if k in torch_results and k not in skip_mapping_torch_onnx
            ]
            if len(candidates) == 1:
                new_name = candidates[0]
                t = torch_results[new_name]
                if (
                    len(set(t.shape)) == len(t.shape)  # not repeated dimension
                    and t.shape == tuple(init.dims)
                    and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
                ):
                    torch_names_to_onnx_names[new_name] = init.name
                else:
                    t = None

        # We should check tensors and proto are the same.
        if t is None:
            t = to_tensor(init)
            if default_device and t.numel() >= 1024:
                # Let's force its way to cuda (should check the device as well).
                t = t.to(default_device)
            records_to_yield.append(
                RunAlignedRecord(
                    onnx_id_node=-1,
                    onnx_name=init.name,
                    onnx_op_type="initializer",
                    onnx_shape_type=string_type(t, **str_kws),
                ).check(already_yielded)
            )

        size = t.element_size() * t.numel()
        if t.is_cuda:
            memory_cuda += size
        else:
            memory_cpu += size
        if init.name not in onnx_results:
            # otherwise, it is an input with a default value
            onnx_results[init.name] = _check_tensor_(use_tensor, init.name, t, flip_type=True)
    return mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield
|
|
146
557
|
|
|
147
558
|
|
|
148
559
|
def run_aligned(
|
|
149
560
|
ep: torch.export.ExportedProgram,
|
|
150
561
|
onx: Union[onnx.ModelProto, onnx.FunctionProto],
|
|
151
|
-
|
|
152
|
-
|
|
562
|
+
run_cls: Callable[
|
|
563
|
+
[
|
|
564
|
+
Union[
|
|
565
|
+
onnx.ModelProto,
|
|
566
|
+
onnx.FunctionProto,
|
|
567
|
+
onnx.GraphProto,
|
|
568
|
+
onnx.NodeProto,
|
|
569
|
+
]
|
|
570
|
+
],
|
|
571
|
+
List[Union[np.ndarray, torch.Tensor]],
|
|
572
|
+
],
|
|
573
|
+
args: Optional[Tuple[torch.Tensor, ...]] = None,
|
|
153
574
|
kwargs: Optional[Dict[str, Any]] = None,
|
|
575
|
+
use_tensor: bool = False,
|
|
576
|
+
atol: Optional[float] = None,
|
|
577
|
+
rtol: Optional[float] = None,
|
|
154
578
|
verbose: int = 0,
|
|
155
|
-
|
|
579
|
+
exc: bool = True,
|
|
580
|
+
reset_names: Optional[List[str]] = None,
|
|
581
|
+
replay_configuration: Optional[ReplayConfiguration] = None,
|
|
582
|
+
run_onnx_with_torch_inputs: bool = False,
|
|
583
|
+
) -> Iterator[RunAlignedRecord]:
|
|
156
584
|
"""
|
|
157
585
|
Runs in parallel both the exported program
|
|
158
586
|
and the onnx proto and looks for discrepancies.
|
|
@@ -162,11 +590,25 @@ def run_aligned(
|
|
|
162
590
|
|
|
163
591
|
:param ep: exported program
|
|
164
592
|
:param onx: model or function proto
|
|
593
|
+
:param run_cls: defines the runtime to use for this task
|
|
165
594
|
:param args: input args
|
|
166
|
-
:param check_conversion_cls: defines the runtime to use for this task
|
|
167
595
|
:param kwargs: input kwargs
|
|
596
|
+
:param use_tensor: use torch tensors instead of numpy arrays
|
|
597
|
+
for the onnx runtime
|
|
598
|
+
:param atol: absolute tolerance
|
|
599
|
+
:param rtol: relative tolerance
|
|
168
600
|
:param verbose: verbosity level
|
|
169
|
-
:
|
|
601
|
+
:param exc: stops if an exception
|
|
602
|
+
:param reset_names: list of names, the onnx execution takes the torch outputs instead
|
|
603
|
+
of its own result if the names falls into that set
|
|
604
|
+
:param replay_configuration: configuration to let the user dump any problematic
|
|
605
|
+
piece of the onnx graph he wants to replay in order to investigate later,
|
|
606
|
+
see :class: `ReplayConfiguration
|
|
607
|
+
<onnx_diagnostic.torch_onnx.sbs.ReplayConfiguration>`
|
|
608
|
+
:param run_onnx_with_torch_inputs: run an onnx operator with torch results
|
|
609
|
+
if they available
|
|
610
|
+
:return: a list of :class:`RunAlignedRecord
|
|
611
|
+
<onnx_diagnostic.torch_onnx.sbs_dataclasses.RunAlignedRecord>`
|
|
170
612
|
|
|
171
613
|
Example:
|
|
172
614
|
|
|
@@ -174,11 +616,10 @@ def run_aligned(
|
|
|
174
616
|
:showcode:
|
|
175
617
|
:warningout: UserWarning
|
|
176
618
|
|
|
177
|
-
import pprint
|
|
178
619
|
import pandas
|
|
179
620
|
import torch
|
|
180
621
|
from onnx_diagnostic.reference import (
|
|
181
|
-
# This can be
|
|
622
|
+
# This can be replaced by any runtime taking NodeProto as an input.
|
|
182
623
|
ExtendedReferenceEvaluator as ReferenceEvaluator,
|
|
183
624
|
)
|
|
184
625
|
from onnx_diagnostic.torch_onnx.sbs import run_aligned
|
|
@@ -193,13 +634,6 @@ def run_aligned(
|
|
|
193
634
|
return ru
|
|
194
635
|
|
|
195
636
|
|
|
196
|
-
def post_process(obs):
|
|
197
|
-
dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
|
|
198
|
-
dobs["err_abs"] = obs[-1]["abs"]
|
|
199
|
-
dobs["err_rel"] = obs[-1]["rel"]
|
|
200
|
-
return dobs
|
|
201
|
-
|
|
202
|
-
|
|
203
637
|
x = torch.randn((5, 4))
|
|
204
638
|
Model()(x) # to make sure the model is running
|
|
205
639
|
ep = torch.export.export(
|
|
@@ -209,131 +643,341 @@ def run_aligned(
|
|
|
209
643
|
Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
|
|
210
644
|
).model_proto
|
|
211
645
|
results = list(
|
|
212
|
-
|
|
213
|
-
post_process,
|
|
214
|
-
run_aligned(
|
|
215
|
-
ep,
|
|
216
|
-
onx,
|
|
217
|
-
(x,),
|
|
218
|
-
check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
|
|
219
|
-
verbose=1,
|
|
220
|
-
),
|
|
221
|
-
),
|
|
646
|
+
run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1)
|
|
222
647
|
)
|
|
223
648
|
print("------------")
|
|
224
649
|
print("final results")
|
|
225
650
|
df = pandas.DataFrame(results)
|
|
651
|
+
df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
|
|
226
652
|
print(df)
|
|
227
|
-
|
|
228
|
-
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
653
|
+
|
|
654
|
+
This example uses :class:`onnx.reference.ReferenceEvaluator` to run the onnx model
|
|
655
|
+
but onnxruntime can also be used through
|
|
656
|
+
:class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`.
|
|
657
|
+
It relies on :epkg:`onnxruntime` and selects CPU or CUDA depending
|
|
658
|
+
on the device where the inputs are located.
|
|
659
|
+
|
|
660
|
+
The :class:`torch.export.ExportedProgram` can be saved on disk
|
|
661
|
+
with ``ep.save("<filename>.pt")`` and restored with
|
|
662
|
+
``torch.export.load("<filename>.pt")``. That leeds the input to save.
|
|
663
|
+
We can decouple the export and the alignment.
|
|
664
|
+
|
|
665
|
+
.. runpython::
|
|
666
|
+
:showcode:
|
|
667
|
+
:warningout: UserWarning
|
|
668
|
+
|
|
669
|
+
import onnx
|
|
670
|
+
import torch
|
|
671
|
+
from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
class Model(torch.nn.Module):
|
|
675
|
+
def forward(self, x):
|
|
676
|
+
ry = x.abs()
|
|
677
|
+
rz = ry.exp()
|
|
678
|
+
rw = rz + 1
|
|
679
|
+
ru = rw.log() + rw
|
|
680
|
+
return ru
|
|
681
|
+
|
|
682
|
+
|
|
683
|
+
x = torch.randn((5, 4))
|
|
684
|
+
dynamic_shapes = ({0: "batch"},)
|
|
685
|
+
Model()(x) # to make sure the model is running
|
|
686
|
+
ep = torch.export.export(
|
|
687
|
+
Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes)
|
|
234
688
|
)
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
689
|
+
onx = torch.onnx.export(
|
|
690
|
+
Model(), (x,), dynamic_shapes=dynamic_shapes
|
|
691
|
+
).model_proto
|
|
238
692
|
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
if isinstance(node.name, str):
|
|
243
|
-
positions[node.name] = dict(fx=i)
|
|
244
|
-
else:
|
|
245
|
-
for n in node.name:
|
|
246
|
-
positions[n] = dict(fx=i)
|
|
693
|
+
torch.export.save(ep, "test_doc_sbs_example.pt2")
|
|
694
|
+
onnx.save(onx, "test_doc_sbs_example.onnx")
|
|
695
|
+
torch.save((x,), "test_doc_sbs_example.pt")
|
|
247
696
|
|
|
248
|
-
|
|
249
|
-
for n in node.output:
|
|
250
|
-
if n in positions:
|
|
251
|
-
positions[n]["onnx"] = i
|
|
252
|
-
else:
|
|
253
|
-
positions[n] = dict(onnx=i)
|
|
697
|
+
Then we can restore all of them and run it.
|
|
254
698
|
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
|
|
263
|
-
|
|
264
|
-
|
|
699
|
+
.. runpython::
|
|
700
|
+
:showcode:
|
|
701
|
+
:warningout: UserWarning
|
|
702
|
+
|
|
703
|
+
import pandas
|
|
704
|
+
import onnx
|
|
705
|
+
import torch
|
|
706
|
+
from onnx_diagnostic.torch_onnx.sbs import run_aligned
|
|
707
|
+
from onnx_diagnostic.reference import OnnxruntimeEvaluator
|
|
708
|
+
|
|
709
|
+
|
|
710
|
+
ep = torch.export.load("test_doc_sbs_example.pt2")
|
|
711
|
+
onx = onnx.load("test_doc_sbs_example.onnx")
|
|
712
|
+
inputs = torch.load("test_doc_sbs_example.pt")
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
results = list(
|
|
716
|
+
run_aligned(
|
|
717
|
+
ep,
|
|
718
|
+
onx,
|
|
719
|
+
OnnxruntimeEvaluator,
|
|
720
|
+
inputs,
|
|
721
|
+
atol=1e-5,
|
|
722
|
+
rtol=1e-5,
|
|
723
|
+
verbose=1,
|
|
724
|
+
use_tensor=True,
|
|
725
|
+
)
|
|
265
726
|
)
|
|
266
|
-
|
|
727
|
+
print("------------")
|
|
728
|
+
print("final results")
|
|
729
|
+
df = pandas.DataFrame(results)
|
|
730
|
+
df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
|
|
731
|
+
print(df)
|
|
732
|
+
|
|
733
|
+
A command line can also be run:
|
|
734
|
+
|
|
735
|
+
.. code-block:: bash
|
|
267
736
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
737
|
+
python -m onnx_diagnostic sbs -i <tensors>.input.pt \\
|
|
738
|
+
--ep <exported_program>.pt2 \\
|
|
739
|
+
-m <model>.onnx \\
|
|
740
|
+
-o results.xlsx \\
|
|
741
|
+
-v 1 --atol=0.1 --rtol=1
|
|
742
|
+
"""
|
|
743
|
+
assert callable(run_cls), f"run_cls={run_cls} not a callable"
|
|
744
|
+
already_yielded = {}
|
|
745
|
+
reset_names = set(reset_names) if reset_names else set()
|
|
746
|
+
str_kws = dict(with_shape=True, with_device=True)
|
|
747
|
+
has_cuda = any(
|
|
748
|
+
(isinstance(t, torch.Tensor) and t.is_cuda)
|
|
749
|
+
for t in flatten_object([args, kwargs], drop_keys=True)
|
|
750
|
+
)
|
|
751
|
+
default_device = None
|
|
752
|
+
if has_cuda:
|
|
753
|
+
for t in flatten_object([args, kwargs], drop_keys=True):
|
|
754
|
+
if t is not None and t.is_cuda:
|
|
755
|
+
default_device = t.device
|
|
756
|
+
break
|
|
757
|
+
run_cls_kwargs = {
|
|
758
|
+
"ir_version": onx.ir_version,
|
|
759
|
+
"opsets": {d.domain: d.version for d in onx.opset_import},
|
|
760
|
+
"verbose": max(verbose - 1, 0),
|
|
761
|
+
"providers": (
|
|
762
|
+
["CUDAExecutionProvider", "CPUExecutionProvider"]
|
|
763
|
+
if has_cuda
|
|
764
|
+
else ["CPUExecutionProvider"]
|
|
765
|
+
),
|
|
272
766
|
}
|
|
273
|
-
|
|
767
|
+
run_cls_kwargs = {
|
|
768
|
+
k: v
|
|
769
|
+
for k, v in run_cls_kwargs.items()
|
|
770
|
+
if k in set(inspect.signature(run_cls).parameters)
|
|
771
|
+
}
|
|
772
|
+
if verbose:
|
|
773
|
+
print(f"[run_aligned] run_cls={run_cls}")
|
|
774
|
+
print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
|
|
775
|
+
if replay_configuration:
|
|
776
|
+
print(f"[run_aligned] replay={replay_configuration}")
|
|
777
|
+
|
|
778
|
+
# preparation with ep.graph.nodes
|
|
779
|
+
ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
|
|
780
|
+
placeholders_to_state_dict = {
|
|
781
|
+
**{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict},
|
|
782
|
+
**{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()},
|
|
783
|
+
**{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants},
|
|
784
|
+
}
|
|
785
|
+
skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
|
|
786
|
+
placeholders = {}
|
|
787
|
+
if verbose:
|
|
788
|
+
print(f"[run_aligned] ep: model has {len(ep_state_dict)} torch constants or weights.")
|
|
789
|
+
|
|
790
|
+
if verbose:
|
|
791
|
+
print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch")
|
|
792
|
+
|
|
793
|
+
# dictionary mapping result names and their position in both graphs.
|
|
794
|
+
positions: Dict[str, Dict[str, int]] = {}
|
|
795
|
+
ep_graph_nodes = list(ep.graph.nodes)
|
|
796
|
+
torch_results: Dict[str, Any] = {}
|
|
274
797
|
torch_output_names = None
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
|
|
283
|
-
|
|
798
|
+
torch_input_names: List[str] = []
|
|
799
|
+
name_to_ep_node = {}
|
|
800
|
+
torch_names_to_onnx_names = {}
|
|
801
|
+
torch_output_names = _preparation_with_fx_graph(
|
|
802
|
+
ep_graph_nodes,
|
|
803
|
+
name_to_ep_node,
|
|
804
|
+
torch_input_names,
|
|
805
|
+
onx,
|
|
806
|
+
torch_names_to_onnx_names,
|
|
807
|
+
skip_mapping_torch_onnx,
|
|
808
|
+
torch_results,
|
|
809
|
+
placeholders,
|
|
810
|
+
placeholders_to_state_dict,
|
|
811
|
+
ep_state_dict,
|
|
812
|
+
positions,
|
|
284
813
|
)
|
|
285
|
-
mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))
|
|
286
814
|
|
|
815
|
+
# preparation for onnx
|
|
287
816
|
if verbose:
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
)
|
|
293
|
-
for k, v in onnx_results.items():
|
|
294
|
-
print(
|
|
295
|
-
f"[run_aligned] +onnx-init: {k}: "
|
|
296
|
-
f"{string_type(v, with_shape=True, with_min_max=True)}"
|
|
297
|
-
)
|
|
817
|
+
print(f"[run_aligned] ep: found {len(torch_results)} torch constants or weights.")
|
|
818
|
+
print(f"[run_aligned] ep: found inputs {torch_input_names}")
|
|
819
|
+
print(f"[run_aligned] ep: found outputs {torch_output_names}")
|
|
820
|
+
print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx")
|
|
298
821
|
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
822
|
+
mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield = (
|
|
823
|
+
_preparation_with_onnx_model(
|
|
824
|
+
default_device,
|
|
825
|
+
use_tensor,
|
|
826
|
+
onx,
|
|
827
|
+
already_yielded,
|
|
828
|
+
str_kws,
|
|
829
|
+
positions,
|
|
830
|
+
torch_names_to_onnx_names,
|
|
831
|
+
torch_output_names,
|
|
832
|
+
torch_results,
|
|
833
|
+
skip_mapping_torch_onnx,
|
|
834
|
+
verbose,
|
|
835
|
+
args,
|
|
836
|
+
kwargs,
|
|
837
|
+
)
|
|
838
|
+
)
|
|
839
|
+
for record in records_to_yield:
|
|
840
|
+
yield record
|
|
306
841
|
|
|
307
|
-
|
|
308
|
-
|
|
842
|
+
if verbose:
|
|
843
|
+
print(f"[run_aligned] nx: handled {len(onnx_results)} initializers from onnx")
|
|
844
|
+
print(f"[run_aligned] nx: memory cpu {memory_cpu / 2**20:.3f} Mb")
|
|
845
|
+
print(f"[run_aligned] nx: memory cuda {memory_cuda / 2**20:.3f} Mb")
|
|
846
|
+
print(f"[run_aligned] nx: {len(onnx_results)} constants")
|
|
847
|
+
print(f"[run_aligned] nx: {len(onx.graph.input)} inputs")
|
|
848
|
+
print(f"[run_aligned] nx: {len(onx.graph.output)} outputs")
|
|
849
|
+
print(f"[run_aligned] bo: {len(mapping_onnx_to_torch)} outputs")
|
|
850
|
+
print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
|
|
851
|
+
if verbose > 1:
|
|
852
|
+
for k, v in torch_results.items():
|
|
853
|
+
print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}")
|
|
854
|
+
for k, v in onnx_results.items():
|
|
855
|
+
print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")
|
|
856
|
+
|
|
857
|
+
# starts the side-by-side
|
|
858
|
+
if verbose:
|
|
859
|
+
print(
|
|
860
|
+
f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} "
|
|
861
|
+
f"fx nodes and {len(onx.graph.node)} onnx nodes"
|
|
862
|
+
)
|
|
863
|
+
if verbose == 1:
|
|
864
|
+
import tqdm
|
|
865
|
+
|
|
866
|
+
loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
|
|
867
|
+
else:
|
|
868
|
+
loop = None
|
|
869
|
+
|
|
870
|
+
already_run: Set[int] = set()
|
|
871
|
+
ep_durations = {}
|
|
872
|
+
status = StatusRunAligned()
|
|
873
|
+
last_position = 0
|
|
874
|
+
for i_torch, node in enumerate(ep_graph_nodes):
|
|
875
|
+
if verbose > 1:
|
|
309
876
|
if node.op == "call_function":
|
|
310
877
|
print(
|
|
311
|
-
f"[run_aligned] run ep.graph.nodes[{
|
|
878
|
+
f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
|
|
312
879
|
f"{node.op}[{node.target}] -> {node.name!r}"
|
|
313
880
|
)
|
|
314
881
|
else:
|
|
315
|
-
print(
|
|
882
|
+
print(
|
|
883
|
+
f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
|
|
884
|
+
)
|
|
885
|
+
elif verbose == 1:
|
|
886
|
+
loop.set_description(
|
|
887
|
+
f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
|
|
888
|
+
f"{status.to_str()}"
|
|
889
|
+
)
|
|
890
|
+
loop.update(min(1, 1 + i_torch + last_position))
|
|
316
891
|
|
|
317
892
|
if node.op == "placeholder":
|
|
318
|
-
|
|
319
|
-
|
|
320
|
-
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
|
|
893
|
+
is_input = node.name not in placeholders
|
|
894
|
+
if is_input:
|
|
895
|
+
torch_results[node.name] = (
|
|
896
|
+
onnx_results[torch_names_to_onnx_names[node.name]]
|
|
897
|
+
if use_tensor
|
|
898
|
+
else from_numpy(onnx_results[torch_names_to_onnx_names[node.name]])
|
|
899
|
+
)
|
|
900
|
+
assert isinstance(torch_results[node.name], torch.Tensor), (
|
|
901
|
+
f"torch_results[{node.name}] not a tensor but "
|
|
902
|
+
f"{type(torch_results[node.name])}, use_tensor={use_tensor}"
|
|
903
|
+
)
|
|
904
|
+
t = torch_results[node.name]
|
|
905
|
+
if verbose > 1:
|
|
906
|
+
print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
|
|
907
|
+
# Otherwise, it is an input.
|
|
908
|
+
record = RunAlignedRecord(
|
|
909
|
+
ep_id_node=i_torch,
|
|
910
|
+
onnx_id_node=-1,
|
|
911
|
+
ep_name=node.name,
|
|
912
|
+
onnx_name=torch_names_to_onnx_names[node.name],
|
|
913
|
+
ep_target="input",
|
|
914
|
+
onnx_op_type="input",
|
|
915
|
+
ep_shape_type=string_type(t, **str_kws),
|
|
916
|
+
onnx_shape_type=string_type(
|
|
917
|
+
onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
|
|
918
|
+
),
|
|
919
|
+
)
|
|
920
|
+
yield record.check(already_yielded)
|
|
921
|
+
else:
|
|
922
|
+
assert node.name in placeholders_to_state_dict, (
|
|
923
|
+
f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
|
|
924
|
+
f"existing: {sorted(placeholders_to_state_dict)}"
|
|
925
|
+
)
|
|
926
|
+
assert node.name in torch_results, (
|
|
927
|
+
f"placeholder {node.name!r} (node.op={node.op!r}), "
|
|
928
|
+
f"should have been added to torch_results: {sorted(torch_results)}"
|
|
929
|
+
)
|
|
930
|
+
t = torch_results[node.name]
|
|
931
|
+
if (
|
|
932
|
+
node.name in torch_names_to_onnx_names
|
|
933
|
+
and node.name not in skip_mapping_torch_onnx
|
|
934
|
+
):
|
|
935
|
+
if verbose > 1:
|
|
936
|
+
print(
|
|
937
|
+
f"[run_aligned-ep] =plh: "
|
|
938
|
+
f"{node.name}={string_type(t, **str_kws)}"
|
|
939
|
+
)
|
|
940
|
+
record = RunAlignedRecord(
|
|
941
|
+
ep_id_node=i_torch,
|
|
942
|
+
onnx_id_node=-1,
|
|
943
|
+
ep_name=node.name,
|
|
944
|
+
onnx_name=torch_names_to_onnx_names[node.name],
|
|
945
|
+
ep_target="placeholder",
|
|
946
|
+
onnx_op_type="initializer",
|
|
947
|
+
ep_shape_type=string_type(t, **str_kws),
|
|
948
|
+
onnx_shape_type=string_type(
|
|
949
|
+
onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
|
|
950
|
+
),
|
|
325
951
|
)
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
952
|
+
if not is_input:
|
|
953
|
+
record.set_diff(
|
|
954
|
+
max_diff(
|
|
955
|
+
t,
|
|
956
|
+
onnx_results[torch_names_to_onnx_names[node.name]],
|
|
957
|
+
hist=[0.1, 0.01],
|
|
958
|
+
)
|
|
959
|
+
)
|
|
960
|
+
yield record.check(already_yielded)
|
|
961
|
+
else:
|
|
962
|
+
if verbose > 1:
|
|
963
|
+
print(
|
|
964
|
+
f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
|
|
965
|
+
)
|
|
966
|
+
yield RunAlignedRecord(
|
|
967
|
+
ep_id_node=i_torch,
|
|
968
|
+
ep_name=node.name,
|
|
969
|
+
ep_target="placeholder",
|
|
970
|
+
ep_shape_type=string_type(t, **str_kws),
|
|
971
|
+
).check(already_yielded)
|
|
972
|
+
continue
|
|
332
973
|
|
|
333
974
|
outputs = [node.name] if isinstance(node.name, str) else list(node.name)
|
|
334
975
|
args, kwargs = prepare_args_kwargs(torch_results, node)
|
|
976
|
+
begin = time.perf_counter()
|
|
335
977
|
new_outputs = run_fx_node(node, args, kwargs)
|
|
336
|
-
|
|
978
|
+
duration = time.perf_counter() - begin
|
|
979
|
+
ep_durations[i_torch] = duration
|
|
980
|
+
if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
|
|
337
981
|
new_outputs = (new_outputs,)
|
|
338
982
|
|
|
339
983
|
if new_outputs is None:
|
|
@@ -342,99 +986,112 @@ def run_aligned(
|
|
|
342
986
|
|
|
343
987
|
for k, v in zip(outputs, new_outputs):
|
|
344
988
|
torch_results[k] = v
|
|
345
|
-
if verbose:
|
|
989
|
+
if verbose > 1:
|
|
346
990
|
for k, v in zip(outputs, new_outputs):
|
|
347
|
-
print(
|
|
348
|
-
f"[run_aligned] +torch {k}="
|
|
349
|
-
f"{string_type(v, with_shape=True, with_min_max=True)}"
|
|
350
|
-
)
|
|
991
|
+
print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}")
|
|
351
992
|
|
|
352
993
|
max_pos = -2
|
|
353
994
|
for n in outputs:
|
|
354
|
-
if n in positions
|
|
355
|
-
|
|
995
|
+
if n in positions:
|
|
996
|
+
if "onnx" in positions[n]:
|
|
997
|
+
max_pos = max(max_pos, positions[n]["onnx"])
|
|
998
|
+
if "fx" in positions[n]:
|
|
999
|
+
if positions[n]["fx"] > i_torch:
|
|
1000
|
+
max_pos = -2
|
|
1001
|
+
break
|
|
356
1002
|
if max_pos == -2:
|
|
357
1003
|
# we skip.
|
|
358
1004
|
continue
|
|
359
1005
|
|
|
1006
|
+
next_to_visit = last_position
|
|
360
1007
|
for i_onnx in range(last_position, max_pos + 1):
|
|
1008
|
+
if i_onnx in already_run:
|
|
1009
|
+
continue
|
|
1010
|
+
# The onnx node may produce more than one output, in that
|
|
1011
|
+
# case, we need to check the exported program is not behind.
|
|
361
1012
|
node = onx.graph.node[i_onnx]
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
for o, r in zip(node.output, res):
|
|
371
|
-
onnx_results[o] = r
|
|
372
|
-
if verbose:
|
|
373
|
-
print(
|
|
374
|
-
f"[run_aligned] +onnx {o}="
|
|
375
|
-
f"{string_type(r, with_shape=True, with_min_max=True)}"
|
|
376
|
-
)
|
|
1013
|
+
ep_behind = False
|
|
1014
|
+
for iname in node.output:
|
|
1015
|
+
if iname in positions and "fx" in positions[iname]:
|
|
1016
|
+
if positions[iname]["fx"] > i_torch:
|
|
1017
|
+
ep_behind = True
|
|
1018
|
+
break
|
|
1019
|
+
if ep_behind:
|
|
1020
|
+
break
|
|
377
1021
|
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
1022
|
+
for r in _loop_onnx_node(
|
|
1023
|
+
onx,
|
|
1024
|
+
ep_graph_nodes,
|
|
1025
|
+
onnx_results,
|
|
1026
|
+
mapping_onnx_to_torch,
|
|
1027
|
+
torch_results,
|
|
1028
|
+
ep_durations,
|
|
1029
|
+
use_tensor,
|
|
1030
|
+
i_torch,
|
|
1031
|
+
i_onnx,
|
|
1032
|
+
name_to_ep_node,
|
|
1033
|
+
run_cls_kwargs,
|
|
1034
|
+
str_kws,
|
|
1035
|
+
status,
|
|
1036
|
+
already_run,
|
|
1037
|
+
torch_names_to_onnx_names,
|
|
1038
|
+
verbose,
|
|
1039
|
+
exc,
|
|
1040
|
+
atol,
|
|
1041
|
+
rtol,
|
|
1042
|
+
reset_names,
|
|
1043
|
+
replay_configuration,
|
|
1044
|
+
has_cuda,
|
|
1045
|
+
run_cls,
|
|
1046
|
+
loop,
|
|
1047
|
+
run_onnx_with_torch_inputs,
|
|
1048
|
+
):
|
|
1049
|
+
if r:
|
|
1050
|
+
yield r.check(already_yielded)
|
|
1051
|
+
next_to_visit = i_onnx + 1
|
|
400
1052
|
|
|
401
|
-
last_position =
|
|
1053
|
+
last_position = next_to_visit
|
|
402
1054
|
|
|
403
1055
|
# complete the execution of the onnx graph
|
|
1056
|
+
if verbose > 1:
|
|
1057
|
+
print(
|
|
1058
|
+
f"[run_aligned] complete execution of onnx graph from pos={last_position} "
|
|
1059
|
+
f"to {len(onx.graph.node)}"
|
|
1060
|
+
)
|
|
404
1061
|
for i_onnx in range(last_position, len(onx.graph.node)):
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
411
|
-
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
|
|
415
|
-
|
|
416
|
-
|
|
417
|
-
|
|
418
|
-
|
|
419
|
-
|
|
420
|
-
|
|
1062
|
+
if i_onnx in already_run:
|
|
1063
|
+
continue
|
|
1064
|
+
for r in _loop_onnx_node(
|
|
1065
|
+
onx,
|
|
1066
|
+
ep_graph_nodes,
|
|
1067
|
+
onnx_results,
|
|
1068
|
+
mapping_onnx_to_torch,
|
|
1069
|
+
torch_results,
|
|
1070
|
+
ep_durations,
|
|
1071
|
+
use_tensor,
|
|
1072
|
+
i_torch,
|
|
1073
|
+
i_onnx,
|
|
1074
|
+
name_to_ep_node,
|
|
1075
|
+
run_cls_kwargs,
|
|
1076
|
+
str_kws,
|
|
1077
|
+
status,
|
|
1078
|
+
already_run,
|
|
1079
|
+
torch_names_to_onnx_names,
|
|
1080
|
+
verbose,
|
|
1081
|
+
exc,
|
|
1082
|
+
atol,
|
|
1083
|
+
rtol,
|
|
1084
|
+
reset_names,
|
|
1085
|
+
replay_configuration,
|
|
1086
|
+
has_cuda,
|
|
1087
|
+
run_cls,
|
|
1088
|
+
loop,
|
|
1089
|
+
run_onnx_with_torch_inputs,
|
|
1090
|
+
):
|
|
1091
|
+
if r:
|
|
1092
|
+
yield r.check(already_yielded)
|
|
421
1093
|
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
if o == to:
|
|
427
|
-
print(f"[run_aligned] =common results* {to}: {string_diff(d)}")
|
|
428
|
-
else:
|
|
429
|
-
print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}")
|
|
430
|
-
if not (
|
|
431
|
-
atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
|
|
432
|
-
):
|
|
433
|
-
skw = dict(with_shape=True, with_min_max=True)
|
|
434
|
-
raise ValueError(
|
|
435
|
-
f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}"
|
|
436
|
-
f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
|
|
437
|
-
f"\n-- onnx_results: {string_type(r, **skw)}"
|
|
438
|
-
f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
|
|
439
|
-
)
|
|
440
|
-
yield (i, i_onnx, o, to, d)
|
|
1094
|
+
if loop is not None:
|
|
1095
|
+
loop.close()
|
|
1096
|
+
if verbose:
|
|
1097
|
+
print(f"[run_aligned] done with status={status.to_str()}")
|