onnx-diagnostic 0.8.1__py3-none-any.whl → 0.8.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +387 -12
  3. onnx_diagnostic/export/api.py +118 -5
  4. onnx_diagnostic/export/control_flow.py +214 -0
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +135 -0
  7. onnx_diagnostic/export/onnx_plug.py +396 -0
  8. onnx_diagnostic/ext_test_case.py +118 -25
  9. onnx_diagnostic/helpers/cache_helper.py +218 -204
  10. onnx_diagnostic/helpers/dot_helper.py +210 -0
  11. onnx_diagnostic/helpers/helper.py +92 -26
  12. onnx_diagnostic/helpers/log_helper.py +26 -4
  13. onnx_diagnostic/helpers/mini_onnx_builder.py +57 -3
  14. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  15. onnx_diagnostic/helpers/onnx_helper.py +115 -16
  16. onnx_diagnostic/helpers/ort_session.py +37 -11
  17. onnx_diagnostic/helpers/rt_helper.py +547 -0
  18. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  19. onnx_diagnostic/helpers/torch_helper.py +108 -6
  20. onnx_diagnostic/reference/ort_evaluator.py +233 -28
  21. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  22. onnx_diagnostic/tasks/image_text_to_text.py +5 -1
  23. onnx_diagnostic/tasks/summarization.py +72 -137
  24. onnx_diagnostic/torch_export_patches/eval/model_cases.py +28 -0
  25. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +1 -1
  26. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +11 -7
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +235 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  34. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  35. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  36. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +680 -0
  37. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  38. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  39. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  40. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  41. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +65 -2107
  42. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +53 -0
  43. onnx_diagnostic/torch_models/hghub/model_inputs.py +15 -2
  44. onnx_diagnostic/torch_models/validate.py +50 -1
  45. onnx_diagnostic/torch_onnx/sbs.py +963 -312
  46. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +491 -0
  47. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/METADATA +1 -1
  48. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/RECORD +51 -30
  49. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/WHEEL +0 -0
  50. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/licenses/LICENSE.txt +0 -0
  51. {onnx_diagnostic-0.8.1.dist-info → onnx_diagnostic-0.8.3.dist-info}/top_level.txt +0 -0
onnx_diagnostic/torch_onnx/sbs.py
@@ -1,158 +1,580 @@
-from typing import Any, Dict, Iterator, Optional, Tuple, Union
+import inspect
+import time
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Sequence, Tuple, Union
 import onnx
+import onnx.helper as oh
+import numpy as np
 import torch
-from ..helpers import string_type, string_diff, max_diff
-from ..helpers.onnx_helper import to_array_extended
-from ..helpers.torch_helper import to_numpy
+from ..helpers import string_type, string_diff, max_diff, flatten_object
+from ..helpers.onnx_helper import pretty_onnx
+from ..helpers.torch_helper import to_numpy, from_numpy, to_tensor, torch_dtype_to_onnx_dtype
+from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
+from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
+from .sbs_dataclasses import (
+    ReplayConfiguration,
+    RunAlignedRecord,
+    StatusRunAligned,
+    make_torch_inputs,
+)


-def validate_fx_tensor(
-    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
-) -> None:
-    """
-    Validates the shape of tensor is expected.
+def _check_tensor_(use_tensor, name, obj, flip_type=False):
+    if flip_type:
+        if use_tensor:
+            if isinstance(obj, np.ndarray):
+                obj = from_numpy(obj)
+        else:
+            if isinstance(obj, torch.Tensor):
+                obj = to_numpy(obj)

-    :param node: node
-    :param tensor: tensor
-    :param expected_shape: expected shape
-    """
-    assert len(tensor.shape) == len(expected_shape), (
-        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
-        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-        f"node.args={node.args}, node.kwargs={node.kwargs}, "
-        f"node.meta={node.meta}"
+    assert not use_tensor or isinstance(obj, (torch.Tensor, OnnxList)), (
+        f"Unexpected type {type(obj)} for {name!r}. "
+        f"use_tensor is True so torch.Tensor is expected."
+    )
+    assert use_tensor or isinstance(obj, (np.ndarray, OnnxList)), (
+        f"Unexpected type {type(obj)} for {name!r}. "
+        f"use_tensor is False so np.array is expected."
     )
-    for a, b in zip(tensor.shape, expected_shape):
-        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
-            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+    return obj
+
+
+def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto:
+    return oh.make_node("Constant", [], [proto.name], value=proto)
+
+
+def _loop_cmp(
+    mapping_onnx_to_torch: Dict[str, str],
+    torch_results: Dict[str, torch.Tensor],
+    onnx_results: Dict[str, Any],
+    onnx_name: str,
+    onnx_result: torch.Tensor,
+    second_onnx_result: torch.Tensor,
+    verbose: int,
+    atol: Optional[float],
+    rtol: Optional[float],
+    i_torch: int,
+    i_onnx: int,
+    str_kws: Dict[str, bool],
+    exc: bool,
+    use_tensor: bool,
+) -> Optional[RunAlignedRecord]:
+    onnx_results[onnx_name] = _check_tensor_(use_tensor, onnx_name, onnx_result)
+    if verbose > 1:
+        print(f"[run_aligned-nx] +res: {onnx_name}={string_type(onnx_result, **str_kws)}")
+
+    to = mapping_onnx_to_torch.get(onnx_name, onnx_name)
+    if to in torch_results:
+        d = max_diff(torch_results[to], onnx_result, hist=[0.1, 0.01])
+        if verbose > 1:
+            if onnx_name == to:
+                print(f"[run_aligned-==] cmp {to}: {string_diff(d)}")
+            else:
+                print(f"[run_aligned-~~] cmp {to}/{onnx_name}: {string_diff(d)}")
+        if not (atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)):
+            if exc:
+                raise ValueError(
+                    f"discrepancies detected for results [{to}/{onnx_name}]: "
+                    f"{string_diff(d)}"
+                    f"\n-- torch_results: {string_type(torch_results[to], **str_kws)}"
+                    f"\n-- onnx_results: {string_type(onnx_result, **str_kws)}"
+                    f"\n-- torch\n{torch_results[to]}"
+                )
+            else:
+                print(
+                    f"[run_aligned-dx] discrepancies {string_diff(d)} - [{to}/{onnx_name}]"
+                )
+        r = RunAlignedRecord(
+            ep_id_node=i_torch,
+            onnx_id_node=i_onnx,
+            ep_name=to,
+            onnx_name=onnx_name,
+            ep_shape_type=string_type(torch_results[to], **str_kws),
+            onnx_shape_type=string_type(onnx_result, **str_kws),
         )
+        r.set_diff(d)
+        if second_onnx_result is not None:
+            d2 = max_diff(torch_results[to], second_onnx_result, hist=[0.1, 0.01])
+            r.set_diff2(d2)
+        mapping_onnx_to_torch[onnx_name] = to
+        return r
+    return None


-def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
-    """
-    Validates the outputs of a node using metadata stored in the node.
+def _duplicated_values(d):
+    rev = {}
+    for k, v in d.items():
+        if v in rev:
+            rev[v].append(k)
+        else:
+            rev[v] = [k]
+    res = {k: v for k, v in rev.items() if len(v) > 1}
+    final = set()
+    for v in res.values():
+        final |= set(v)
+    return final

-    :param node: node
-    :param outputs: outputs
-    """
-    if "val" not in node.meta:
-        return
-    if isinstance(outputs, torch.Tensor):
-        validate_fx_tensor(node, outputs, node.meta["val"].shape)
-        return
-    if isinstance(outputs, (tuple, list)):
-        assert isinstance(node.meta["val"], (list, tuple)), (
-            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+
+def _validation_nn_functional(
+    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
+) -> Optional[str]:
+    if node.op_type == "Gemm" and len(node.input) == 3:
+        atts = {}
+        for att in node.attribute:
+            if att.name in ("alpha", "beta"):
+                atts[att.name] = att.f
+            elif att.name in ("transA", "transB"):
+                atts[att.name] = att.i
+        if atts == {"transB": 1}:
+            res = torch.nn.functional.linear(*[new_feeds[i] for i in node.input])
+            diff = max_diff(res, expected[0])
+            return f"function.linear:{string_diff(diff)}"
+    return None
+
+
+def _loop_onnx_node(
+    onx: onnx.ModelProto,
+    ep_graph_nodes: List[torch.fx.Node],
+    onnx_results: Dict[str, Any],
+    mapping_onnx_to_torch: Dict[str, str],
+    torch_results: Dict[str, torch.Tensor],
+    ep_durations,
+    use_tensor: bool,
+    i_torch: int,
+    i_onnx: int,
+    name_to_ep_node: Dict[str, int],
+    run_cls_kwargs: Dict[str, Any],
+    str_kws: Dict[str, bool],
+    status: StatusRunAligned,
+    already_run_onnx: Set[int],
+    torch_names_to_onnx_names: Dict[str, str],
+    verbose: int,
+    exc: bool,
+    atol: float,
+    rtol: float,
+    reset_names: Set[str],
+    replay_configuration: Optional[ReplayConfiguration],
+    has_cuda: bool,
+    run_cls: type,
+    loop: Any,
+    run_onnx_with_torch_inputs: bool,
+) -> Iterator[Optional[RunAlignedRecord]]:
+
+    if i_onnx in already_run_onnx:
+        yield None
+    node = onx.graph.node[i_onnx]
+    if verbose > 1:
+        print(
+            f"[run_aligned] run onx.graph.node[{i_onnx}]: "
+            f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
         )
-        assert len(outputs) == len(node.meta["val"]), (
-            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+    elif verbose == 1:
+        loop.set_description(
+            f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
+            f"{status.to_str()}"
         )
-        for a, b in zip(outputs, node.meta["val"]):
-            validate_fx_tensor(node, a, b.shape)
-        return
-    if isinstance(outputs, int):
-        assert (
-            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
-            or outputs == node.meta["val"]
-        ), (
-            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+        loop.update(min(1, 1 + i_torch + i_onnx))
+
+    ref = run_cls(node, **run_cls_kwargs)
+    # We need to clone because the runtime may be using dlpack to create OrtValue.
+    hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node)
+    all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input
+    feeds = (
+        {k: onnx_results[k].clone() for k in all_inputs if k}
+        if use_tensor
+        else {k: onnx_results[k].copy() for k in all_inputs if k}
+    )
+    assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
+    if verbose > 1:
+        print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}")
+    begin = time.perf_counter()
+    try:
+        res = ref.run(None, feeds)  # type: ignore[attr-defined]
+    except Exception as e:
+        raise RuntimeError(
+            f"Unable to run node {node.op_type}, domain={node.domain} "
+            f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
+        ) from e
+    duration = time.perf_counter() - begin
+    if verbose > 1:
+        print(f"[run_aligned] res={string_type(res, **str_kws)}")
+    assert (
+        not has_cuda
+        or not any(t is not None and t.is_cuda for t in feeds.values())
+        or any(t is not None and t.is_cuda for t in res)
+        or node.op_type in {"Shape", "Size"}  # on CPU no matter what
+        or node.op_type
+        in {
+            "Add",
+            "Concat",
+            "Div",
+            "Gather",
+            "Mul",
+            "Range",
+            "Squeeze",
+            "Sub",
+            "Unsqueeze",
+        }  # not sure, could be about shapes
+    ), (
+        f"One input is on cuda but there is no float output on cuda, "
+        f"feeds={string_type(feeds, with_device=True, with_shape=True)}, "
+        f"res={string_type(res, with_device=True, with_shape=True)}, "
+        f"node is {pretty_onnx(node)}"
+    )
+
+    comment = None
+    cross = None
+    if run_onnx_with_torch_inputs:
+        # Let's run the operator with torch results if they are available.
+        new_feeds, removed = make_torch_inputs(
+            node.input,
+            {
+                **{v: k for k, v in torch_names_to_onnx_names.items()},
+                **mapping_onnx_to_torch,
+            },
+            onnx_results,
+            torch_results,
+            submodel=None,
         )
-        return
-    if outputs is None:
-        assert node.meta["val"] is None, (
-            f"None mismatch, got {outputs} expected {node.meta['val']}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+        if not removed:
+            if verbose > 1:
+                print(
+                    f"[run_aligned] feeds for second run="
+                    f"{string_type(new_feeds, **str_kws)}"
+                )
+            cross = ref.run(None, new_feeds)
+            if verbose > 1:
+                print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
+            # Gemm = torch.nn.functional.linear, in that case, we just run it as well.
+            to = mapping_onnx_to_torch.get(node.output[0], node.output[0])
+            if to in torch_results:
+                comment = _validation_nn_functional(node, new_feeds, [torch_results[to]])
+        elif verbose > 1:
+            print(f"[run_aligned] second run not possible because of missing {removed}")
+
+    if cross is None:
+        cross = [None for _ in res]
+
+    list_node_output = list(node.output)
+    node_output = [o for o in list_node_output if o]
+    for o, r, r2 in zip(node_output, res, cross):
+        if r is None or not o:
+            continue
+        tmp = _loop_cmp(
+            mapping_onnx_to_torch,
+            torch_results,
+            onnx_results,
+            o,
+            r,
+            r2,
+            verbose,
+            atol,
+            rtol,
+            i_torch,
+            i_onnx,
+            str_kws,
+            exc,
+            use_tensor,
         )
-        return
-    raise NotImplementedError(
-        f"Validation for output type {type(outputs)} is not implemented, "
-        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-        f"node.args={node.args}, node.kwargs={node.kwargs}, "
-        f"node.meta={node.meta}"
-    )
+        if tmp is not None:
+            if tmp.ep_name in name_to_ep_node:
+                tmp.ep_id_node = name_to_ep_node[tmp.ep_name]
+                tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target)
+                tmp.ep_time_run = ep_durations[tmp.ep_id_node]
+            else:
+                tmp.ep_id_node = None
+                tmp.ep_target = None
+                tmp.ep_name = None
+            tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type
+            tmp.onnx_id_output = list_node_output.index(o)
+            tmp.onnx_time_run = duration
+            status.yielded_nodes += 1
+            if tmp.err_abs is not None:
+                status.update(tmp.err_abs)
+            tmp.comment = comment
+            yield tmp

+            # do we need to dump pieces of the graph the user can replay?
+            if replay_configuration:
+                if replay_configuration.select(
+                    name=tmp.onnx_name, op_type=tmp.onnx_op_type, err_abs=tmp.err_abs
+                ):
+                    replay_configuration.dump(
+                        name=tmp.onnx_name,
+                        onnx_id_node=tmp.onnx_id_node,
+                        model=onx,
+                        onnx_results=onnx_results,
+                        torch_results=torch_results,
+                        onnx_name_to_ep_name={
+                            **{v: k for k, v in torch_names_to_onnx_names.items()},
+                            **mapping_onnx_to_torch,
+                        },
+                        verbose=max(verbose - 1, 0),
+                    )
+                    status.last_replay = tmp.onnx_name

-def run_fx_node(
-    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> Tuple[Any, ...]:
-    """
-    Executes a node
+            # reset_names: replaces onnx_results by torch_results to see
+            # if that fixes the discrepancies problem
+            if reset_names and tmp.ep_name in reset_names:
+                assert (
+                    tmp.ep_name in torch_results
+                ), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
+                assert (
+                    tmp.onnx_name in onnx_results
+                ), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
+                onnx_results[tmp.onnx_name] = torch_results[tmp.ep_name]
+                tmp = _loop_cmp(
+                    mapping_onnx_to_torch,
+                    torch_results,
+                    onnx_results,
+                    o,
+                    torch_results[tmp.ep_name],
+                    None,
+                    verbose,
+                    atol,
+                    rtol,
+                    i_torch,
+                    i_onnx,
+                    str_kws,
+                    exc,
+                    use_tensor,
+                )
+                assert tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}"
+                if tmp is not None:
+                    tmp.onnx_op_type = "reset"
+                    tmp.onnx_id_output = list_node_output.index(o)
+                    status.yielded_nodes += 1
+                    yield tmp
+    already_run_onnx.add(i_onnx)

-    :param node: runs a node
-    :param args: unnamed inputs to the node
-    :param kwargs: named inputs to the node
-    :return: results
-    """
-    if node.op == "output":
-        assert len(args) == 1 and not kwargs, (
-            f"Unexpected inputs: args={string_type(args, limit=20)} "
-            f"kwargs={string_type(kwargs, limit=20)}"
+
+def _preparation_with_fx_graph(
+    ep_graph_nodes: List[torch.fx.Node],
+    name_to_ep_node: Dict[str, int],
+    torch_input_names: List[str],
+    onx: onnx.ModelProto,
+    torch_names_to_onnx_names: Dict[str, str],
+    skip_mapping_torch_onnx,
+    torch_results: Dict[str, torch.Tensor],
+    placeholders: Dict[str, torch.Tensor],
+    placeholders_to_state_dict: Dict[str, str],
+    ep_state_dict: Dict[str, torch.Tensor],
+    positions: Dict[str, Dict[str, int]],
+) -> List[str]:
+    torch_output_names = None
+    for i, node in enumerate(ep_graph_nodes):
+        if isinstance(node.name, str):
+            positions[node.name] = dict(fx=i)
+        else:
+            for n in node.name:
+                positions[n] = dict(fx=i)
+        if node.op == "placeholder":
+            if node.name in placeholders_to_state_dict:
+                # This is a weight.
+                placeholders[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
+                torch_results[node.name] = placeholders[node.name]
+                assert isinstance(torch_results[node.name], torch.Tensor), (
+                    f"torch_results[{node.name}] not a tensor but "
+                    f"{type(torch_results[node.name])}"
+                )
+            else:
+                # This is an input.
+                assert len(torch_input_names) < len(onx.graph.input), (
+                    f"torch_input_names={torch_input_names!r}, "
+                    f"onnx_input_names={[n.name for n in onx.graph.input]}, "
+                    f"node.name={node.name!r} cannot be an input"
+                )
+                assert node.name not in skip_mapping_torch_onnx, (
+                    f"{node.name!r} is ambiguous, cannot be mapped due to "
+                    f"{skip_mapping_torch_onnx}"
+                )
+                torch_names_to_onnx_names[node.name] = onx.graph.input[
+                    len(torch_input_names)
+                ].name
+                torch_input_names.append(node.name)
+        elif node.op == "output":
+            torch_output_names = [n.name for n in node.args[0]]
+        assert isinstance(node.name, str), (
+            f"Unexpected type {type(node.name)} for node={node} (target={node.target}), "
+            f"args={node.args}"
         )
-        return args
-    if node.op == "call_function":
-        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
-        outputs = node.target(*args, **(kwargs or {}))
-        validate_fx_outputs(node, outputs)
-        return outputs
-    raise NotImplementedError(
-        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+        name_to_ep_node[node.name] = i
+    assert torch_output_names is not None, "No output node was found in the graph."
+    return torch_output_names
+
+
+def _preparation_with_onnx_model(
+    default_device,
+    use_tensor: bool,
+    onx: onnx.ModelProto,
+    already_yielded: Dict[str, Any],
+    str_kws: Dict[str, bool],
+    positions: Dict[str, Dict[str, int]],
+    torch_names_to_onnx_names: Dict[str, str],
+    torch_output_names: List[str],
+    torch_results: Dict[str, torch.Tensor],
+    skip_mapping_torch_onnx: Set[str],
+    verbose: int,
+    args: Sequence[Any],
+    kwargs: Dict[str, Any],
+) -> Tuple[Dict[str, str], Dict[str, torch.Tensor], float, float, List[RunAlignedRecord]]:
+    for inp in onx.graph.input:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
+    for inp in onx.graph.initializer:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
+    for i, node in enumerate(onx.graph.node):
+        for n in node.output:
+            if n in positions:
+                positions[n]["onnx"] = i
+            else:
+                positions[n] = dict(onnx=i)
+
+    onnx_outputs_names = [o.name for o in onx.graph.output]
+    assert torch_output_names is not None and len(torch_output_names) == len(
+        onnx_outputs_names
+    ), (
+        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
+        f"onnx_outputs_names={onnx_outputs_names}"
     )
+    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

+    onnx_args = list(args) if args else []
+    if kwargs:
+        onnx_args.extend(flatten_object(kwargs, drop_keys=True))
+    if verbose:
+        print(f"[run_aligned] args: {string_type(args, **str_kws)}")
+        print(f"[run_aligned] kwargs: {string_type(kwargs, **str_kws)}")
+        print(f"[run_aligned] onnx: {string_type(onnx_args, **str_kws)}")
+        print(f"[run_aligned] nx: walks through {len(onx.graph.input)} onnx inputs")
+    onnx_results: Dict[str, Any] = {}
+    for inp, v in zip(onx.graph.input, onnx_args):
+        onnx_results[inp.name] = _check_tensor_(
+            use_tensor, inp.name, v if use_tensor else to_numpy(v)
+        )
+        if verbose:
+            print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")

-def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
-    "See :func:`prepare_args_kwargs`."
-    if isinstance(ref, torch.fx.Node):
-        return torch_results[ref.name]
-    if isinstance(ref, list):
-        return [_pick_result(torch_results, n) for n in ref]
-    if isinstance(ref, tuple):
-        return tuple(_pick_result(torch_results, n) for n in ref)
-    if isinstance(ref, dict):
-        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
-    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
-        return ref
-    if ref is None:
-        return None
-    raise NotImplementedError(f"Unable to process args type {type(ref)}")
-
-
-def prepare_args_kwargs(
-    torch_results: Dict[str, Any], node: torch.fx.Node
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
-    """
-    Prepares args and kwargs before executing a fx node.
+    # alias for initializers
+    skip_onnx_name = set()
+    init_aliases: Dict[str, str] = {}
+    for init in onx.graph.initializer:
+        new_names = {
+            n
+            for n in [
+                f"p_{init.name.replace('.', '_')}",
+                f"p_{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+                f"{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+            ]
+            if n != init.name
+        }
+        drop = False
+        for new_name in new_names:
+            if new_name in skip_onnx_name:
+                drop = True
+                break
+        if drop:
+            skip_onnx_name |= new_names | {init.name}
+            for new_name in new_names:
+                if new_name in init_aliases:
+                    del init_aliases[new_name]
+        else:
+            for new_name in new_names:
+                init_aliases[new_name] = init.name
+    rev_init_aliases: Dict[str, Set[str]] = {}
+    for k, v in init_aliases.items():
+        if v in rev_init_aliases:
+            rev_init_aliases[v].add(k)
+        else:
+            rev_init_aliases[v] = {k}

-    :param torch_results: existing results
-    :param node: node to execute
-    :return: new args and kwargs
-    """
-    new_args = _pick_result(torch_results, node.args)
-    new_kwargs = _pick_result(torch_results, node.kwargs)
-    return new_args, new_kwargs
+    # initializers
+    if verbose:
+        print(f"[run_aligned] nx: handles {len(onx.graph.initializer)} initializers from onnx")
+    memory_cpu = 0
+    memory_cuda = 0
+    records_to_yield = []
+    for init in onx.graph.initializer:  # type: ignore
+        t = None
+        if init.name in torch_results:
+            if init.name not in skip_mapping_torch_onnx:
+                t = torch_results[init.name]
+                torch_names_to_onnx_names[init.name] = init.name
+        elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
+            new_names = [
+                k
+                for k in rev_init_aliases[init.name]
+                if k in torch_results and k not in skip_mapping_torch_onnx
+            ]
+            if new_names and len(new_names) == 1:
+                new_name = new_names[0]  # type: ignore[assignment, index]
+                t = torch_results[new_name]
+                if (
+                    len(set(t.shape)) == len(t.shape)  # no repeated dimension
+                    and t.shape == tuple(init.dims)
+                    and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
+                ):
+                    torch_names_to_onnx_names[new_name] = init.name
+                else:
+                    t = None
+
+        # We should check tensors and proto are the same.
+        if t is None:
+            t = to_tensor(init)
+            if default_device and t.numel() >= 1024:
+                # Let's force its way to cuda (should check the device as well).
+                t = t.to(default_device)
+        records_to_yield.append(
+            RunAlignedRecord(
+                onnx_id_node=-1,
+                onnx_name=init.name,
+                onnx_op_type="initializer",
+                onnx_shape_type=string_type(t, **str_kws),
+            ).check(already_yielded)
+        )
+
+        size = t.element_size() * t.numel()
+        if t.is_cuda:
+            memory_cuda += size
+        else:
+            memory_cpu += size
+        if init.name not in onnx_results:
+            # otherwise, it is an input with a default value
+            onnx_results[init.name] = _check_tensor_(use_tensor, init.name, t, flip_type=True)
+    return mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield


 def run_aligned(
     ep: torch.export.ExportedProgram,
     onx: Union[onnx.ModelProto, onnx.FunctionProto],
-    args: Tuple[torch.Tensor, ...],
-    check_conversion_cls: Union[Dict[str, Any], type],
+    run_cls: Callable[
+        [
+            Union[
+                onnx.ModelProto,
+                onnx.FunctionProto,
+                onnx.GraphProto,
+                onnx.NodeProto,
+            ]
+        ],
+        List[Union[np.ndarray, torch.Tensor]],
+    ],
+    args: Optional[Tuple[torch.Tensor, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
+    use_tensor: bool = False,
+    atol: Optional[float] = None,
+    rtol: Optional[float] = None,
     verbose: int = 0,
-) -> Iterator[Tuple[Any, ...]]:
+    exc: bool = True,
+    reset_names: Optional[List[str]] = None,
+    replay_configuration: Optional[ReplayConfiguration] = None,
+    run_onnx_with_torch_inputs: bool = False,
+) -> Iterator[RunAlignedRecord]:
     """
     Runs in parallel both the exported program
     and the onnx proto and looks for discrepancies.
@@ -162,11 +584,25 @@ def run_aligned(

     :param ep: exported program
     :param onx: model or function proto
+    :param run_cls: defines the runtime to use for this task
     :param args: input args
-    :param check_conversion_cls: defines the runtime to use for this task
     :param kwargs: input kwargs
+    :param use_tensor: use torch tensors instead of numpy arrays
+        for the onnx runtime
+    :param atol: absolute tolerance
+    :param rtol: relative tolerance
     :param verbose: verbosity level
-    :return: a list of tuples containing the results, they come in tuple,
+    :param exc: if True, raises an exception when discrepancies exceed the tolerances
+    :param reset_names: list of names, the onnx execution takes the torch outputs instead
+        of its own results if the name falls into that set
+    :param replay_configuration: configuration letting the user dump any problematic
+        piece of the onnx graph they want to replay and investigate later,
+        see :class:`ReplayConfiguration
+        <onnx_diagnostic.torch_onnx.sbs.ReplayConfiguration>`
+    :param run_onnx_with_torch_inputs: run an onnx operator with torch results
+        if they are available
+    :return: an iterator of :class:`RunAlignedRecord
+        <onnx_diagnostic.torch_onnx.sbs_dataclasses.RunAlignedRecord>`

     Example:

@@ -174,11 +610,10 @@
         :showcode:
         :warningout: UserWarning

-        import pprint
         import pandas
         import torch
         from onnx_diagnostic.reference import (
-            # This can be replace by any runtime taking NodeProto as an input.
+            # This can be replaced by any runtime taking NodeProto as an input.
             ExtendedReferenceEvaluator as ReferenceEvaluator,
         )
         from onnx_diagnostic.torch_onnx.sbs import run_aligned
@@ -193,13 +628,6 @@
                 return ru


-        def post_process(obs):
-            dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
-            dobs["err_abs"] = obs[-1]["abs"]
-            dobs["err_rel"] = obs[-1]["rel"]
-            return dobs
-
-
         x = torch.randn((5, 4))
         Model()(x)  # to make sure the model is running
         ep = torch.export.export(
@@ -209,131 +637,341 @@
             Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
         ).model_proto
         results = list(
-            map(
-                post_process,
-                run_aligned(
-                    ep,
-                    onx,
-                    (x,),
-                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
-                    verbose=1,
-                ),
-            ),
+            run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1)
         )
         print("------------")
         print("final results")
         df = pandas.DataFrame(results)
+        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
         print(df)
-    """
-    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
-    cls, atol, rtol = (
-        (
-            check_conversion_cls["cls"],
-            check_conversion_cls["atol"],
-            check_conversion_cls["rtol"],
+
+    This example uses :class:`onnx.reference.ReferenceEvaluator` to run the onnx model
+    but onnxruntime can also be used through
+    :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`.
+    It relies on :epkg:`onnxruntime` and selects CPU or CUDA depending
+    on the device where the inputs are located.
+
+    The :class:`torch.export.ExportedProgram` can be saved on disk
+    with ``ep.save("<filename>.pt")`` and restored with
+    ``torch.export.load("<filename>.pt")``. That leaves the inputs to save.
+    We can then decouple the export from the alignment.
+
+    .. runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                ry = x.abs()
+                rz = ry.exp()
+                rw = rz + 1
+                ru = rw.log() + rw
+                return ru
+
+
+        x = torch.randn((5, 4))
+        dynamic_shapes = ({0: "batch"},)
+        Model()(x)  # to make sure the model is running
+        ep = torch.export.export(
+            Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes)
         )
-        if isinstance(check_conversion_cls, dict)
-        else (check_conversion_cls, None, None)
-    )
+        onx = torch.onnx.export(
+            Model(), (x,), dynamic_shapes=dynamic_shapes
+        ).model_proto

-    # retrieve the positions
-    positions: Dict[str, Any] = {}
-    for i, node in enumerate(ep.graph.nodes):
-        if isinstance(node.name, str):
-            positions[node.name] = dict(fx=i)
-        else:
-            for n in node.name:
-                positions[n] = dict(fx=i)
+        torch.export.save(ep, "test_doc_sbs_example.pt2")
+        onnx.save(onx, "test_doc_sbs_example.onnx")
+        torch.save((x,), "test_doc_sbs_example.pt")

-    for i, node in enumerate(onx.graph.node):
-        for n in node.output:
-            if n in positions:
-                positions[n]["onnx"] = i
-            else:
-                positions[n] = dict(onnx=i)
+    Then we can restore all of them and run the alignment.

-    onnx_results: Dict[str, Any] = {}
-    for init in onx.graph.initializer:  # type: ignore
-        positions[init.name] = -1
-        onnx_results[init.name] = to_array_extended(init)
-        param_name = f"p_{init.name.replace('.', '_')}"
-        if param_name == init.name:
-            continue
-        assert param_name not in onnx_results, (
-            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
-            f"and onnx_results has {sorted(onnx_results)}"
+    .. runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import pandas
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_onnx.sbs import run_aligned
+        from onnx_diagnostic.reference import OnnxruntimeEvaluator
+
+
+        ep = torch.export.load("test_doc_sbs_example.pt2")
+        onx = onnx.load("test_doc_sbs_example.onnx")
+        inputs = torch.load("test_doc_sbs_example.pt")
+
+
+        results = list(
+            run_aligned(
+                ep,
+                onx,
+                OnnxruntimeEvaluator,
+                inputs,
+                atol=1e-5,
+                rtol=1e-5,
+                verbose=1,
+                use_tensor=True,
+            )
         )
-        onnx_results[param_name] = onnx_results[init.name]
+        print("------------")
+        print("final results")
+        df = pandas.DataFrame(results)
+        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
+        print(df)
+
+    A command line can also be run:
+
+    .. code-block:: bash

-    torch_results: Dict[str, Any] = {
-        k: torch.from_numpy(v.copy())
-        for k, v in onnx_results.items()
-        if not k.startswith("init")
+        python -m onnx_diagnostic sbs -i <tensors>.input.pt \\
+            --ep <exported_program>.pt2 \\
+            -m <model>.onnx \\
+            -o results.xlsx \\
+            -v 1 --atol=0.1 --rtol=1
+    """
+    assert callable(run_cls), f"run_cls={run_cls} not a callable"
+    already_yielded = {}
+    reset_names = set(reset_names) if reset_names else set()
+    str_kws = dict(with_shape=True, with_device=True)
+    has_cuda = any(
+        (isinstance(t, torch.Tensor) and t.is_cuda)
+        for t in flatten_object([args, kwargs], drop_keys=True)
+    )
+    default_device = None
+    if has_cuda:
+        for t in flatten_object([args, kwargs], drop_keys=True):
+            if t is not None and t.is_cuda:
+                default_device = t.device
+                break
+    run_cls_kwargs = {
+        "ir_version": onx.ir_version,
+        "opsets": {d.domain: d.version for d in onx.opset_import},
+        "verbose": max(verbose - 1, 0),
+        "providers": (
+            ["CUDAExecutionProvider", "CPUExecutionProvider"]
+            if has_cuda
+            else ["CPUExecutionProvider"]
+        ),
     }
-    last_position = 0
+    run_cls_kwargs = {
+        k: v
+        for k, v in run_cls_kwargs.items()
+        if k in set(inspect.signature(run_cls).parameters)
+    }
+    if verbose:
+        print(f"[run_aligned] run_cls={run_cls}")
+        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
+        if replay_configuration:
+            print(f"[run_aligned] replay={replay_configuration}")
+
+    # preparation with ep.graph.nodes
+    ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
+    placeholders_to_state_dict = {
+        **{f"p_{name.replace('.', '_')}": name for name in ep.state_dict},
+        **{f"b_{name.replace('.', '_')}": name for name, _ in ep.named_buffers()},
+        **{f"c_{name.replace('.', '_')}": name for name in ep.tensor_constants},
+    }
+    skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
+    placeholders = {}
+    if verbose:
+        print(f"[run_aligned] ep: model has {len(ep_state_dict)} torch constants or weights.")
+
+    if verbose:
+        print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch")
+
+    # dictionary mapping result names to their position in both graphs.
+    positions: Dict[str, Dict[str, int]] = {}
+    ep_graph_nodes = list(ep.graph.nodes)
+    torch_results: Dict[str, Any] = {}
     torch_output_names = None
-    for node in ep.graph.nodes:
-        if node.op == "output":
-            torch_output_names = [n.name for n in node.args[0]]
-    onnx_outputs_names = [o.name for o in onx.graph.output]
-    assert torch_output_names is not None and len(torch_output_names) == len(
-        onnx_outputs_names
-    ), (
-        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
-        f"onnx_outputs_names={onnx_outputs_names}"
+    torch_input_names: List[str] = []
+    name_to_ep_node = {}
+    torch_names_to_onnx_names = {}
+    torch_output_names = _preparation_with_fx_graph(
+        ep_graph_nodes,
+        name_to_ep_node,
+        torch_input_names,
+        onx,
+        torch_names_to_onnx_names,
+        skip_mapping_torch_onnx,
+        torch_results,
+        placeholders,
+        placeholders_to_state_dict,
+        ep_state_dict,
+        positions,
     )
-    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))

+    # preparation for onnx
     if verbose:
-        for k, v in torch_results.items():
-            print(
-                f"[run_aligned] +torch-cst: {k}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
-        for k, v in onnx_results.items():
-            print(
-                f"[run_aligned] +onnx-init: {k}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
+        print(f"[run_aligned] ep: found {len(torch_results)} torch constants or weights.")
+        print(f"[run_aligned] ep: found inputs {torch_input_names}")
+        print(f"[run_aligned] ep: found outputs {torch_output_names}")
+        print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx")

-    for inp, v in zip(onx.graph.input, args):
-        onnx_results[inp.name] = to_numpy(v)
-        if verbose:
-            print(
-                f"[run_aligned] +onnx-input: {inp.name}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
+    mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield = (
+        _preparation_with_onnx_model(
+            default_device,
+            use_tensor,
+            onx,
+            already_yielded,
+            str_kws,
+            positions,
+            torch_names_to_onnx_names,
+            torch_output_names,
+            torch_results,
+            skip_mapping_torch_onnx,
+            verbose,
+            args,
+            kwargs,
+        )
+    )
+    for record in records_to_yield:
+        yield record

-    for i, node in enumerate(ep.graph.nodes):
-        if verbose:
+    if verbose:
+        print(f"[run_aligned] nx: handled {len(onnx_results)} initializers from onnx")
+        print(f"[run_aligned] nx: memory cpu {memory_cpu / 2**20:.3f} Mb")
+        print(f"[run_aligned] nx: memory cuda {memory_cuda / 2**20:.3f} Mb")
+        print(f"[run_aligned] nx: {len(onnx_results)} constants")
+        print(f"[run_aligned] nx: {len(onx.graph.input)} inputs")
+        print(f"[run_aligned] nx: {len(onx.graph.output)} outputs")
+        print(f"[run_aligned] bo: {len(mapping_onnx_to_torch)} outputs")
+        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
+    if verbose > 1:
+        for k, v in torch_results.items():
+            print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}")
+        for k, v in onnx_results.items():
+            print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")
+
+    # starts the side-by-side
+    if verbose:
+        print(
+            f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} "
+            f"fx nodes and {len(onx.graph.node)} onnx nodes"
+        )
+    if verbose == 1:
+        import tqdm
+
+        loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
+    else:
+        loop = None
+
+    already_run: Set[int] = set()
+    ep_durations = {}
+    status = StatusRunAligned()
+    last_position = 0
+    for i_torch, node in enumerate(ep_graph_nodes):
+        if verbose > 1:
             if node.op == "call_function":
                 print(
-                    f"[run_aligned] run ep.graph.nodes[{i}]: "
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
                     f"{node.op}[{node.target}] -> {node.name!r}"
                 )
             else:
-                print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")
+                print(
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
+                )
+        elif verbose == 1:
+            loop.set_description(
+                f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
+                f"{status.to_str()}"
+            )
+            loop.update(min(1, 1 + i_torch + last_position))

         if node.op == "placeholder":
-            if node.name in onnx_results:
-                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
-                if verbose:
-                    t = torch_results[node.name]
-                    print(
-                        f"[run_aligned] +torch {node.name}="
-                        f"{string_type(t, with_shape=True, with_min_max=True)}"
+            is_input = node.name not in placeholders
+            if is_input:
+                torch_results[node.name] = (
+                    onnx_results[torch_names_to_onnx_names[node.name]]
+                    if use_tensor
+                    else from_numpy(onnx_results[torch_names_to_onnx_names[node.name]])
+                )
+                assert isinstance(torch_results[node.name], torch.Tensor), (
+                    f"torch_results[{node.name}] not a tensor but "
+                    f"{type(torch_results[node.name])}, use_tensor={use_tensor}"
+                )
+                t = torch_results[node.name]
+                if verbose > 1:
+                    print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
+                # Otherwise, it is an input.
+                record = RunAlignedRecord(
+                    ep_id_node=i_torch,
+                    onnx_id_node=-1,
+                    ep_name=node.name,
+                    onnx_name=torch_names_to_onnx_names[node.name],
+                    ep_target="input",
+                    onnx_op_type="input",
+                    ep_shape_type=string_type(t, **str_kws),
+                    onnx_shape_type=string_type(
+                        onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
+                    ),
+                )
+                yield record.check(already_yielded)
+            else:
+                assert node.name in placeholders_to_state_dict, (
+                    f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
+                    f"existing: {sorted(placeholders_to_state_dict)}"
+                )
+                assert node.name in torch_results, (
+                    f"placeholder {node.name!r} (node.op={node.op!r}), "
+                    f"should have been added to torch_results: {sorted(torch_results)}"
+                )
+                t = torch_results[node.name]
+                if (
+                    node.name in torch_names_to_onnx_names
+                    and node.name not in skip_mapping_torch_onnx
+                ):
+                    if verbose > 1:
+                        print(
+                            f"[run_aligned-ep] =plh: "
+                            f"{node.name}={string_type(t, **str_kws)}"
+                        )
+                    record = RunAlignedRecord(
+                        ep_id_node=i_torch,
+                        onnx_id_node=-1,
+                        ep_name=node.name,
+                        onnx_name=torch_names_to_onnx_names[node.name],
+                        ep_target="placeholder",
+                        onnx_op_type="initializer",
+                        ep_shape_type=string_type(t, **str_kws),
+                        onnx_shape_type=string_type(
+                            onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
+                        ),
                     )
-            continue
-        raise AssertionError(
-            f"unable to process node {node.op} -> {node.name!r} "
-            f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
-            f"onx.graph.input={[i.name for i in onx.graph.input]}"
-        )
+                    if not is_input:
+                        record.set_diff(
+                            max_diff(
+                                t,
+                                onnx_results[torch_names_to_onnx_names[node.name]],
+                                hist=[0.1, 0.01],
+                            )
+                        )
+                    yield record.check(already_yielded)
+                else:
+                    if verbose > 1:
+                        print(
+                            f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
+                        )
+                    yield RunAlignedRecord(
+                        ep_id_node=i_torch,
+                        ep_name=node.name,
+                        ep_target="placeholder",
+                        ep_shape_type=string_type(t, **str_kws),
+                    ).check(already_yielded)
+            continue

         outputs = [node.name] if isinstance(node.name, str) else list(node.name)
         args, kwargs = prepare_args_kwargs(torch_results, node)
+        begin = time.perf_counter()
         new_outputs = run_fx_node(node, args, kwargs)
-        if isinstance(new_outputs, (torch.Tensor, int, float, list)):
+        duration = time.perf_counter() - begin
+        ep_durations[i_torch] = duration
+        if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
             new_outputs = (new_outputs,)

         if new_outputs is None:
@@ -342,99 +980,112 @@ def run_aligned(

         for k, v in zip(outputs, new_outputs):
             torch_results[k] = v
-        if verbose:
+        if verbose > 1:
             for k, v in zip(outputs, new_outputs):
-                print(
-                    f"[run_aligned] +torch {k}="
-                    f"{string_type(v, with_shape=True, with_min_max=True)}"
-                )
+                print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}")

         max_pos = -2
         for n in outputs:
-            if n in positions and "onnx" in positions[n]:
-                max_pos = max(max_pos, positions[n]["onnx"])
+            if n in positions:
+                if "onnx" in positions[n]:
+                    max_pos = max(max_pos, positions[n]["onnx"])
+                if "fx" in positions[n]:
+                    if positions[n]["fx"] > i_torch:
+                        max_pos = -2
+                        break
         if max_pos == -2:
             # we skip.
             continue

+        next_to_visit = last_position
         for i_onnx in range(last_position, max_pos + 1):
+            if i_onnx in already_run:
+                continue
+            # The onnx node may produce more than one output, in that
+            # case, we need to check the exported program is not behind.
             node = onx.graph.node[i_onnx]
-            if verbose:
-                print(
-                    f"[run_aligned] run onx.graph.node[{i_onnx}]: "
-                    f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
-                )
-            ref = cls(node)
-            feeds = {k: onnx_results[k] for k in node.input}
-            res = ref.run(None, feeds)
-            for o, r in zip(node.output, res):
-                onnx_results[o] = r
-                if verbose:
-                    print(
-                        f"[run_aligned] +onnx {o}="
-                        f"{string_type(r, with_shape=True, with_min_max=True)}"
-                    )
+            ep_behind = False
+            for iname in node.output:
+                if iname in positions and "fx" in positions[iname]:
+                    if positions[iname]["fx"] > i_torch:
+                        ep_behind = True
+                        break
+            if ep_behind:
+                break

-                to = mapping_onnx_to_torch.get(o, o)
-                if to in torch_results:
-                    d = max_diff(torch_results[to], r)
-                    if verbose:
-                        if o == to:
-                            print(f"[run_aligned] =common results {to}: {string_diff(d)}")
-                        else:
-                            print(f"[run_aligned] =common results {to}/{o}: {string_diff(d)}")
-                    if not (
-                        atol is None
-                        or rtol is None
-                        or (d["abs"] <= atol and d["rel"] <= rtol)
-                    ):
-                        skw = dict(with_shape=True, with_min_max=True)
-                        raise ValueError(
-                            f"discrepancies detected for results [{to}/{o}]: "
-                            f"{string_diff(d)}"
-                            f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
-                            f"\n-- onnx_results: {string_type(r, **skw)}"
-                            f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
-                        )
-                    yield (i, i_onnx, o, to, d)
+            for r in _loop_onnx_node(
+                onx,
+                ep_graph_nodes,
+                onnx_results,
+                mapping_onnx_to_torch,
+                torch_results,
+                ep_durations,
+                use_tensor,
+                i_torch,
+                i_onnx,
+                name_to_ep_node,
+                run_cls_kwargs,
+                str_kws,
+                status,
+                already_run,
+                torch_names_to_onnx_names,
+                verbose,
+                exc,
+                atol,
+                rtol,
+                reset_names,
+                replay_configuration,
+                has_cuda,
+                run_cls,
+                loop,
+                run_onnx_with_torch_inputs,
+            ):
+                if r:
+                    yield r.check(already_yielded)
+            next_to_visit = i_onnx + 1

-        last_position = max_pos + 1
+        last_position = next_to_visit

     # complete the execution of the onnx graph
+    if verbose > 1:
+        print(
+            f"[run_aligned] complete execution of onnx graph from pos={last_position} "
+            f"to {len(onx.graph.node)}"
+        )
     for i_onnx in range(last_position, len(onx.graph.node)):
-        node = onx.graph.node[i_onnx]
-        if verbose:
-            print(
-                f"[run_aligned] run onx.graph.node[{i_onnx}]: "
-                f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
-            )
-        ref = cls(node)
-        feeds = {k: onnx_results[k] for k in node.input}
-        res = ref.run(None, feeds)
-        for o, r in zip(node.output, res):
-            onnx_results[o] = r
-            if verbose:
-                print(
-                    f"[run_aligned] +onnx {o}="
-                    f"{string_type(r, with_shape=True, with_min_max=True)}"
-                )
+        if i_onnx in already_run:
+            continue
+        for r in _loop_onnx_node(
+            onx,
+            ep_graph_nodes,
+            onnx_results,
+            mapping_onnx_to_torch,
+            torch_results,
+            ep_durations,
+            use_tensor,
+            i_torch,
+            i_onnx,
+            name_to_ep_node,
+            run_cls_kwargs,
+            str_kws,
+            status,
+            already_run,
+            torch_names_to_onnx_names,
+            verbose,
+            exc,
+            atol,
+            rtol,
+            reset_names,
+            replay_configuration,
+            has_cuda,
+            run_cls,
+            loop,
+            run_onnx_with_torch_inputs,
+        ):
+            if r:
+                yield r.check(already_yielded)

-            to = mapping_onnx_to_torch.get(o, o)
-            if to in torch_results:
-                d = max_diff(torch_results[to], r)
-                if verbose:
-                    if o == to:
-                        print(f"[run_aligned] =common results* {to}: {string_diff(d)}")
-                    else:
-                        print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}")
-                if not (
-                    atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
-                ):
-                    skw = dict(with_shape=True, with_min_max=True)
-                    raise ValueError(
-                        f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}"
-                        f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
-                        f"\n-- onnx_results: {string_type(r, **skw)}"
-                        f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
-                    )
-                yield (i, i_onnx, o, to, d)
+    if loop is not None:
+        loop.close()
+    if verbose:
+        print(f"[run_aligned] done with status={status.to_str()}")