onnx-diagnostic 0.8.2__py3-none-any.whl → 0.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +412 -12
  3. onnx_diagnostic/export/api.py +111 -8
  4. onnx_diagnostic/export/control_flow.py +48 -345
  5. onnx_diagnostic/export/control_flow_onnx.py +528 -0
  6. onnx_diagnostic/export/control_flow_research.py +12 -7
  7. onnx_diagnostic/export/onnx_plug.py +531 -0
  8. onnx_diagnostic/ext_test_case.py +163 -48
  9. onnx_diagnostic/helpers/cache_helper.py +1 -1
  10. onnx_diagnostic/helpers/dot_helper.py +222 -0
  11. onnx_diagnostic/helpers/helper.py +108 -37
  12. onnx_diagnostic/helpers/mini_onnx_builder.py +3 -1
  13. onnx_diagnostic/helpers/model_builder_helper.py +27 -0
  14. onnx_diagnostic/helpers/onnx_helper.py +531 -6
  15. onnx_diagnostic/helpers/ort_session.py +45 -19
  16. onnx_diagnostic/helpers/torch_fx_graph_helper.py +164 -0
  17. onnx_diagnostic/helpers/torch_helper.py +131 -8
  18. onnx_diagnostic/reference/ort_evaluator.py +228 -46
  19. onnx_diagnostic/tasks/feature_extraction.py +15 -14
  20. onnx_diagnostic/tasks/summarization.py +72 -137
  21. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_attention.py +236 -0
  22. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_cache_utils.py +50 -0
  23. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_causal_mask.py +89 -0
  24. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_dynamic_cache.py +177 -0
  25. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_gemma3.py +54 -0
  26. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_generation_mixin.py +486 -0
  27. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_idefics.py +156 -0
  28. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_masking_utils.py +173 -0
  29. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2.py +99 -0
  30. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen2_5.py +735 -0
  31. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_qwen3.py +106 -0
  32. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_rotary_embedding.py +412 -0
  33. onnx_diagnostic/torch_export_patches/patches/_patch_transformers_sam_mask_decoder.py +132 -0
  34. onnx_diagnostic/torch_export_patches/patches/patch_helper.py +28 -0
  35. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +64 -2608
  36. onnx_diagnostic/torch_models/code_sample.py +2 -1
  37. onnx_diagnostic/torch_models/hghub/model_inputs.py +34 -7
  38. onnx_diagnostic/torch_models/validate.py +64 -2
  39. onnx_diagnostic/torch_onnx/runtime_info.py +1 -24
  40. onnx_diagnostic/torch_onnx/sbs.py +969 -312
  41. onnx_diagnostic/torch_onnx/sbs_dataclasses.py +535 -0
  42. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/METADATA +1 -1
  43. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/RECORD +46 -27
  44. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/WHEEL +0 -0
  45. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  46. {onnx_diagnostic-0.8.2.dist-info → onnx_diagnostic-0.8.4.dist-info}/top_level.txt +0 -0
@@ -1,158 +1,586 @@
-from typing import Any, Dict, Iterator, Optional, Tuple, Union
+import inspect
+import time
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Sequence, Tuple, Union
 import onnx
+import onnx.helper as oh
+import numpy as np
 import torch
-from ..helpers import string_type, string_diff, max_diff
-from ..helpers.onnx_helper import to_array_extended
-from ..helpers.torch_helper import to_numpy
+from ..helpers import string_type, string_diff, max_diff, flatten_object
+from ..helpers.onnx_helper import pretty_onnx
+from ..helpers.torch_helper import (
+    to_numpy,
+    from_numpy,
+    to_tensor,
+    torch_dtype_to_onnx_dtype,
+)
+from ..helpers.torch_fx_graph_helper import prepare_args_kwargs, run_fx_node
+from ..reference.ort_evaluator import OnnxList, OnnxruntimeEvaluator
+from .sbs_dataclasses import (
+    ReplayConfiguration,
+    RunAlignedRecord,
+    StatusRunAligned,
+    make_torch_inputs,
+)
 
 
-def validate_fx_tensor(
-    node: torch.fx.Node, tensor: torch.Tensor, expected_shape: Tuple[Any, ...]
-) -> None:
-    """
-    Validates the shape of tensor is expected.
+def _check_tensor_(use_tensor, name, obj, flip_type=False):
+    if flip_type:
+        if use_tensor:
+            if isinstance(obj, np.ndarray):
+                obj = from_numpy(obj)
+        else:
+            if isinstance(obj, torch.Tensor):
+                obj = to_numpy(obj)
 
-    :param node: node
-    :param tensor: tensor
-    :param expected_shape: expected shape
-    """
-    assert len(tensor.shape) == len(expected_shape), (
-        f"Shape mismatch, got {tensor.shape} expected {expected_shape}, "
-        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-        f"node.args={node.args}, node.kwargs={node.kwargs}, "
-        f"node.meta={node.meta}"
+    assert not use_tensor or isinstance(obj, (torch.Tensor, OnnxList)), (
+        f"Unexpected type {type(obj)} for {name!r}. "
+        f"use_tensor is True so torch.Tensor is expected."
+    )
+    assert use_tensor or isinstance(obj, (np.ndarray, OnnxList)), (
+        f"Unexpected type {type(obj)} for {name!r}. "
+        f"use_tensor is False so np.array is expected."
     )
-    for a, b in zip(tensor.shape, expected_shape):
-        assert not isinstance(b, int) or a == b or {a, b} == {0, 1}, (
-            f"Dimension mismatch, got {tensor.shape} expected {expected_shape}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+    return obj
+
+
+def _make_node_from_initializer(proto: onnx.TensorProto) -> onnx.NodeProto:
+    return oh.make_node("Constant", [], [proto.name], value=proto)
+
+
+def _loop_cmp(
+    mapping_onnx_to_torch: Dict[str, str],
+    torch_results: Dict[str, torch.Tensor],
+    onnx_results: Dict[str, Any],
+    onnx_name: str,
+    onnx_result: torch.Tensor,
+    second_onnx_result: torch.Tensor,
+    verbose: int,
+    atol: Optional[float],
+    rtol: Optional[float],
+    i_torch: int,
+    i_onnx: int,
+    str_kws: Dict[str, bool],
+    exc: bool,
+    use_tensor: bool,
+) -> Optional[RunAlignedRecord]:
+    onnx_results[onnx_name] = _check_tensor_(use_tensor, onnx_name, onnx_result)
+    if verbose > 1:
+        print(f"[run_aligned-nx] +res: {onnx_name}={string_type(onnx_result, **str_kws)}")
+
+    to = mapping_onnx_to_torch.get(onnx_name, onnx_name)
+    if to in torch_results:
+        d = max_diff(torch_results[to], onnx_result, hist=[0.1, 0.01])
+        if verbose > 1:
+            if onnx_name == to:
+                print(f"[run_aligned-==] cmp {to}: {string_diff(d)}")
+            else:
+                print(f"[run_aligned-~~] cmp {to}/{onnx_name}: {string_diff(d)}")
+        if not (atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)):
+            if exc:
+                raise ValueError(
+                    f"discrepancies detected for results [{to}/{onnx_name}]: "
+                    f"{string_diff(d)}"
+                    f"\n-- torch_results: {string_type(torch_results[to], **str_kws)}"
+                    f"\n-- onnx_results: {string_type(onnx_result, **str_kws)}"
+                    f"\n-- torch\n{torch_results[to]}"
+                )
+            else:
+                print(
+                    f"[run_aligned-dx] discrepancies {string_diff(d)} - [{to}/{onnx_name}]"
+                )
+        r = RunAlignedRecord(
+            ep_id_node=i_torch,
+            onnx_id_node=i_onnx,
+            ep_name=to,
+            onnx_name=onnx_name,
+            ep_shape_type=string_type(torch_results[to], **str_kws),
+            onnx_shape_type=string_type(onnx_result, **str_kws),
         )
+        r.set_diff(d)
+        if second_onnx_result is not None:
+            d2 = max_diff(torch_results[to], second_onnx_result, hist=[0.1, 0.01])
+            r.set_diff2(d2)
+        mapping_onnx_to_torch[onnx_name] = to
+        return r
+    return None
 
 
-def validate_fx_outputs(node: torch.fx.Node, outputs: Tuple[Any, ...]) -> None:
-    """
-    Validates the outputs of a node using metadata stored in the node.
+def _duplicated_values(d):
+    rev = {}
+    for k, v in d.items():
+        if v in rev:
+            rev[v].append(k)
+        else:
+            rev[v] = [k]
+    res = {k: v for k, v in rev.items() if len(v) > 1}
+    final = set()
+    for v in res.values():
+        final |= set(v)
+    return final
 
-    :param node: node
-    :param outputs: outputs
-    """
-    if "val" not in node.meta:
-        return
-    if isinstance(outputs, torch.Tensor):
-        validate_fx_tensor(node, outputs, node.meta["val"].shape)
-        return
-    if isinstance(outputs, (tuple, list)):
-        assert isinstance(node.meta["val"], (list, tuple)), (
-            f"Unexpected type {string_type(node.meta['val'])} for node.meta['val'], "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+
+def _validation_nn_functional(
+    node: onnx.NodeProto, new_feeds: Dict[str, torch.Tensor], expected: List[torch.Tensor]
+) -> Optional[str]:
+    if node.op_type == "Gemm" and len(node.input) == 3:
+        atts = {}
+        for att in node.attribute:
+            if att.name in ("alpha", "beta"):
+                atts[att.name] = att.f
+            elif att.name in ("transA", "transB"):
+                atts[att.name] = att.i
+        if atts == {"transB": 1}:
+            res = torch.nn.functional.linear(*[new_feeds[i] for i in node.input])
+            diff = max_diff(res, expected[0])
+            return f"function.linear:{string_diff(diff)}"
+    return None
+
+
+def _loop_onnx_node(
+    onx: onnx.ModelProto,
+    ep_graph_nodes: List[torch.fx.Node],
+    onnx_results: Dict[str, Any],
+    mapping_onnx_to_torch: Dict[str, str],
+    torch_results: Dict[str, torch.Tensor],
+    ep_durations,
+    use_tensor: bool,
+    i_torch: int,
+    i_onnx: int,
+    name_to_ep_node: Dict[str, int],
+    run_cls_kwargs: Dict[str, Any],
+    str_kws: Dict[str, bool],
+    status: StatusRunAligned,
+    already_run_onnx: Set[int],
+    torch_names_to_onnx_names: Dict[str, str],
+    verbose: int,
+    exc: bool,
+    atol: float,
+    rtol: float,
+    reset_names: Set[str],
+    replay_configuration: Optional[ReplayConfiguration],
+    has_cuda: bool,
+    run_cls: type,
+    loop: Any,
+    run_onnx_with_torch_inputs: bool,
+) -> Iterator[Optional[RunAlignedRecord]]:
+
+    if i_onnx in already_run_onnx:
+        yield None
+    node = onx.graph.node[i_onnx]
+    if verbose > 1:
+        print(
+            f"[run_aligned] run onx.graph.node[{i_onnx}]: "
+            f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
         )
-        assert len(outputs) == len(node.meta["val"]), (
-            f"Length mismatch, got {len(outputs)} expected {len(node.meta['val'])}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+    elif verbose == 1:
+        loop.set_description(
+            f"ep {i_torch}/{len(ep_graph_nodes)} nx {i_onnx}/{len(onx.graph.node)} "
+            f"{status.to_str()}"
         )
-        for a, b in zip(outputs, node.meta["val"]):
-            validate_fx_tensor(node, a, b.shape)
-        return
-    if isinstance(outputs, int):
-        assert (
-            isinstance(node.meta["val"], (torch.SymInt, torch.SymBool, torch.SymFloat))
-            or outputs == node.meta["val"]
-        ), (
-            f"Int mismatch, got {outputs} expected {node.meta['val']}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+    loop.update(min(1, 1 + i_torch + i_onnx))
+
+    ref = run_cls(node, **run_cls_kwargs)
183
+ # We need to clone because the runtime maybe using dlpack to create OrtValue
184
+    hidden_inputs = OnnxruntimeEvaluator._get_hidden_node_inputs(node)
+    all_inputs = [*node.input, *hidden_inputs] if hidden_inputs else node.input
+    feeds = (
+        {k: onnx_results[k].clone() for k in all_inputs if k}
+        if use_tensor
+        else {k: onnx_results[k].copy() for k in all_inputs if k}
+    )
+    assert "" not in feeds, f"Unexpected feeds={string_type(feeds, **str_kws)}"
+    if verbose > 1:
+        print(f"[run_aligned] feeds={string_type(feeds, **str_kws)}")
+    begin = time.perf_counter()
+    try:
+        res = ref.run(None, feeds)  # type: ignore[attr-defined]
+    except Exception as e:
+        raise RuntimeError(
+            f"Unable to run node {node.op_type}, domain={node.domain} "
+            f"with inputs={node.input}, feeds={string_type(feeds, **str_kws)}"
+        ) from e
+    duration = time.perf_counter() - begin
+    if verbose > 1:
+        print(f"[run_aligned] res={string_type(res, **str_kws)}")
+    assert (
+        not has_cuda
+        or not any(t is not None and t.is_cuda for t in feeds.values())
+        or any(t is not None and t.is_cuda for t in res)
+        or node.op_type in {"Shape", "Size"}  # on CPU no matter what
+        or node.op_type
+        in {
+            "Add",
+            "Concat",
+            "Div",
+            "Gather",
+            "Mul",
+            "Range",
+            "Squeeze",
+            "Sub",
+            "Unsqueeze",
+        }  # not sure, could be about shapes
+    ), (
+        f"One input is on cuda but there is no float output on cuda, "
+        f"feeds={string_type(feeds, with_device=True, with_shape=True)}, "
+        f"res={string_type(res, with_device=True, with_shape=True)}, "
+        f"node is {pretty_onnx(node)}"
+    )
+
+    comment = None
+    cross = None
+    if run_onnx_with_torch_inputs:
+        # Let's run the operator with torch results if they are available
+        new_feeds, removed = make_torch_inputs(
+            node.input,
+            {
+                **{v: k for k, v in torch_names_to_onnx_names.items()},
+                **mapping_onnx_to_torch,
+            },
+            onnx_results,
+            torch_results,
+            submodel=None,
         )
-        return
-    if outputs is None:
-        assert node.meta["val"] is None, (
-            f"None mismatch, got {outputs} expected {node.meta['val']}, "
-            f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-            f"node.args={node.args}, node.kwargs={node.kwargs}, "
-            f"node.meta={node.meta}"
+        if not removed:
+            if verbose > 1:
+                print(
+                    f"[run_aligned] feeds for second run="
+                    f"{string_type(new_feeds, **str_kws)}"
+                )
+            cross = ref.run(None, new_feeds)
+            if verbose > 1:
+                print(f"[run_aligned] got for second run={string_type(cross, **str_kws)}")
+            # Gemm = torch.nn.functional.linear; in that case, we just run it as well
+            to = mapping_onnx_to_torch.get(node.output[0], node.output[0])
+            if to in torch_results:
+                comment = _validation_nn_functional(node, new_feeds, [torch_results[to]])
+        elif verbose > 1:
+            print(f"[run_aligned] second run not possible because of missing {removed}")
+
+    if cross is None:
+        cross = [None for _ in res]
+
+    list_node_output = list(node.output)
+    node_output = [o for o in list_node_output if o]
+    for o, r, r2 in zip(node_output, res, cross):
+        if r is None or not o:
+            continue
+        tmp = _loop_cmp(
+            mapping_onnx_to_torch,
+            torch_results,
+            onnx_results,
+            o,
+            r,
+            r2,
+            verbose,
+            atol,
+            rtol,
+            i_torch,
+            i_onnx,
+            str_kws,
+            exc,
+            use_tensor,
         )
-        return
-    raise NotImplementedError(
-        f"Validation for output type {type(outputs)} is not implemented, "
-        f"node.name={node.name!r}, node.target={getattr(node, 'target', None)}, "
-        f"node.args={node.args}, node.kwargs={node.kwargs}, "
-        f"node.meta={node.meta}"
-    )
+        if tmp is not None:
+            if tmp.ep_name in name_to_ep_node:
+                tmp.ep_id_node = name_to_ep_node[tmp.ep_name]
+                tmp.ep_target = str(ep_graph_nodes[tmp.ep_id_node].target)
+                tmp.ep_time_run = ep_durations[tmp.ep_id_node]
+            else:
+                tmp.ep_id_node = None
+                tmp.ep_target = None
+                tmp.ep_name = None
+            tmp.onnx_op_type = onx.graph.node[tmp.onnx_id_node].op_type
+            tmp.onnx_id_output = list_node_output.index(o)
+            tmp.onnx_time_run = duration
+            status.yielded_nodes += 1
+            if tmp.err_abs is not None:
+                status.update(tmp.err_abs)
+            tmp.comment = comment
+            yield tmp
 
+            # do we need to dump pieces of the graph the user can replay?
+            if replay_configuration:
+                if replay_configuration.select(
+                    name=tmp.onnx_name, op_type=tmp.onnx_op_type, err_abs=tmp.err_abs
+                ):
+                    replay_configuration.dump(
+                        name=tmp.onnx_name,
+                        onnx_id_node=tmp.onnx_id_node,
+                        model=onx,
+                        onnx_results=onnx_results,
+                        torch_results=torch_results,
+                        onnx_name_to_ep_name={
+                            **{v: k for k, v in torch_names_to_onnx_names.items()},
+                            **mapping_onnx_to_torch,
+                        },
+                        verbose=max(verbose - 1, 0),
+                    )
+                    status.last_replay = tmp.onnx_name
 
-def run_fx_node(
-    node: torch.fx.Node, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
-) -> Tuple[Any, ...]:
-    """
-    Executes a node
+            # reset_names: replaces onnx_results by torch_results to see
+            # if that fixes the discrepancies problem
+            if reset_names and tmp.ep_name in reset_names:
+                assert (
+                    tmp.ep_name in torch_results
+                ), f"name {tmp.ep_name!r} set to be reset is missing in torch_results."
+                assert (
+                    tmp.onnx_name in onnx_results
+                ), f"name {tmp.onnx_name!r} set to be reset is missing in onnx_results."
+                onnx_results[tmp.onnx_name] = torch_results[tmp.ep_name]
+                tmp = _loop_cmp(
+                    mapping_onnx_to_torch,
+                    torch_results,
+                    onnx_results,
+                    o,
+                    torch_results[tmp.ep_name],
+                    None,
+                    verbose,
+                    atol,
+                    rtol,
+                    i_torch,
+                    i_onnx,
+                    str_kws,
+                    exc,
+                    use_tensor,
+                )
+                assert tmp.err_abs == 0, f"Reset did not happen, tmp={tmp}"
+                if tmp is not None:
+                    tmp.onnx_op_type = "reset"
+                    tmp.onnx_id_output = list_node_output.index(o)
+                    status.yielded_nodes += 1
+                    yield tmp
+    already_run_onnx.add(i_onnx)
 
-    :param node: runs a node
-    :param args: unnamed inputs to the node
-    :param kwargs: named inputs to the node
-    :return: results
-    """
-    if node.op == "output":
-        assert len(args) == 1 and not kwargs, (
-            f"Unexpected inputs: args={string_type(args, limit=20)} "
-            f"kwargs={string_type(kwargs, limit=20)}"
+
+def _preparation_with_fx_graph(
+    ep_graph_nodes: List[torch.fx.Node],
+    name_to_ep_node: Dict[str, int],
+    torch_input_names: List[str],
+    onx: onnx.ModelProto,
+    torch_names_to_onnx_names: Dict[str, str],
+    skip_mapping_torch_onnx,
+    torch_results: Dict[str, torch.Tensor],
+    placeholders: Dict[str, torch.Tensor],
+    placeholders_to_state_dict: Dict[str, str],
+    ep_state_dict: Dict[str, torch.Tensor],
+    positions: Dict[str, Dict[str, int]],
+) -> List[str]:
+    torch_output_names = None
+    for i, node in enumerate(ep_graph_nodes):
+        if isinstance(node.name, str):
+            positions[node.name] = dict(fx=i)
+        else:
+            for n in node.name:
+                positions[n] = dict(fx=i)
+        if node.op == "placeholder":
+            if node.name in placeholders_to_state_dict:
+                # This is a weight.
+                placeholders[node.name] = ep_state_dict[placeholders_to_state_dict[node.name]]
+                torch_results[node.name] = placeholders[node.name]
+                assert isinstance(torch_results[node.name], torch.Tensor), (
+                    f"torch_results[{node.name}] not a tensor but "
+                    f"{type(torch_results[node.name])}"
+                )
+            else:
+                # This is an input
+                assert len(torch_input_names) < len(onx.graph.input), (
+                    f"torch_input_names={torch_input_names!r}, "
+                    f"onnx_input_names={[n.name for n in onx.graph.input]}, "
+                    f"node.name={node.name!r} cannot be an input, "
+                    f"placeholders_to_state_dict={sorted(placeholders_to_state_dict)}"
+                )
+                assert node.name not in skip_mapping_torch_onnx, (
+                    f"{node.name!r} is ambiguous, cannot be mapped due to "
+                    f"{skip_mapping_torch_onnx}"
+                )
+                torch_names_to_onnx_names[node.name] = onx.graph.input[
+                    len(torch_input_names)
+                ].name
+                torch_input_names.append(node.name)
+        elif node.op == "output":
+            torch_output_names = [n.name for n in node.args[0]]
+        assert isinstance(node.name, str), (
+            f"Unexpected type {type(node.name)} for node={node} (target={node.target}), "
+            f"args={node.args}"
         )
-        return args
-    if node.op == "call_function":
-        assert callable(node.target), f"{node.target!r} not callable in node {node!r}"
-        outputs = node.target(*args, **(kwargs or {}))
-        validate_fx_outputs(node, outputs)
-        return outputs
-    raise NotImplementedError(
-        f"node.op={node.op!r} is not implemented, node.name={node.name!r}"
+        name_to_ep_node[node.name] = i
+    assert torch_output_names is not None, "No output node was found in the graph."
+    return torch_output_names
+
+
+def _preparation_with_onnx_model(
+    default_device,
+    use_tensor: bool,
+    onx: onnx.ModelProto,
+    already_yielded: Dict[str, Any],
+    str_kws: Dict[str, bool],
+    positions: Dict[str, Dict[str, int]],
+    torch_names_to_onnx_names: Dict[str, str],
+    torch_output_names: List[str],
+    torch_results: Dict[str, torch.Tensor],
+    skip_mapping_torch_onnx: Set[str],
+    verbose: int,
+    args: Sequence[Any],
+    kwargs: Dict[str, Any],
+) -> Tuple[Dict[str, str], Dict[str, torch.Tensor], float, float, List[RunAlignedRecord]]:
+    for inp in onx.graph.input:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
+    for inp in onx.graph.initializer:
+        n = inp.name
+        if n in positions:
+            positions[n]["onnx"] = -1
+        else:
+            positions[n] = dict(onnx=-1)
+    for i, node in enumerate(onx.graph.node):
+        for n in node.output:
+            if n in positions:
+                positions[n]["onnx"] = i
+            else:
+                positions[n] = dict(onnx=i)
+
+    onnx_outputs_names = [o.name for o in onx.graph.output]
+    assert torch_output_names is not None and len(torch_output_names) == len(
+        onnx_outputs_names
+    ), (
+        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
+        f"onnx_outputs_names={onnx_outputs_names}"
     )
+    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))
 
+    onnx_args = list(args) if args else []
+    if kwargs:
+        onnx_args.extend(flatten_object(kwargs, drop_keys=True))
+    if verbose:
+        print(f"[run_aligned] args: {string_type(args, **str_kws)}")
+        print(f"[run_aligned] kwargs: {string_type(kwargs, **str_kws)}")
+        print(f"[run_aligned] onnx: {string_type(onnx_args, **str_kws)}")
+        print(f"[run_aligned] nx: walks through {len(onx.graph.input)} onnx inputs")
+    onnx_results: Dict[str, Any] = {}
+    for inp, v in zip(onx.graph.input, onnx_args):
+        onnx_results[inp.name] = _check_tensor_(
+            use_tensor, inp.name, v if use_tensor else to_numpy(v)
+        )
+        if verbose:
+            print(f"[run_aligned-nx] +inp: {inp.name}: {string_type(v, **str_kws)}")
 
-def _pick_result(torch_results: Dict[str, Any], ref: Any) -> Any:
-    "See :func:`prepare_args_kwargs`."
-    if isinstance(ref, torch.fx.Node):
-        return torch_results[ref.name]
-    if isinstance(ref, list):
-        return [_pick_result(torch_results, n) for n in ref]
-    if isinstance(ref, tuple):
-        return tuple(_pick_result(torch_results, n) for n in ref)
-    if isinstance(ref, dict):
-        return {k: _pick_result(torch_results, v) for k, v in ref.items()}
-    if isinstance(ref, (bool, int, float, str, torch.device, torch.dtype)):
-        return ref
-    if ref is None:
-        return None
-    raise NotImplementedError(f"Unable to process args type {type(ref)}")
-
-
-def prepare_args_kwargs(
-    torch_results: Dict[str, Any], node: torch.fx.Node
-) -> Tuple[Tuple[Any, ...], Dict[str, Any]]:
-    """
-    Prepares args and kwargs before executing a fx node.
+    # alias for initializers
+    skip_onnx_name = set()
+    init_aliases: Dict[str, str] = {}
+    for init in onx.graph.initializer:
+        new_names = {
+            n
+            for n in [
+                f"p_{init.name.replace('.', '_')}",
+                f"p_{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+                f"{init.name.split('::')[0].split('--')[-1].replace('.', '_')}",
+            ]
+            if n != init.name
+        }
+        drop = False
+        for new_name in new_names:
+            if new_name in skip_onnx_name:
+                drop = True
+                break
+        if drop:
+            skip_onnx_name |= new_names | {init.name}
+            for new_name in new_names:
+                if new_name in init_aliases:
+                    del init_aliases[new_name]
+        else:
+            for new_name in new_names:
+                init_aliases[new_name] = init.name
+    rev_init_aliases: Dict[str, Set[str]] = {}
+    for k, v in init_aliases.items():
+        if v in rev_init_aliases:
+            rev_init_aliases[v].add(k)
+        else:
+            rev_init_aliases[v] = {k}
 
-    :param torch_results: existing results
-    :param node: node to execute
-    :return: new args and kwargs
-    """
-    new_args = _pick_result(torch_results, node.args)
-    new_kwargs = _pick_result(torch_results, node.kwargs)
-    return new_args, new_kwargs
+    # initializers
+    if verbose:
+        print(f"[run_aligned] nx: handles {len(onx.graph.initializer)} initializers from onnx")
+    memory_cpu = 0
+    memory_cuda = 0
+    records_to_yield = []
+    for init in onx.graph.initializer:  # type: ignore
+        t = None
+        if init.name in torch_results:
+            if init.name not in skip_mapping_torch_onnx:
+                t = torch_results[init.name]
+                torch_names_to_onnx_names[init.name] = init.name
+        elif init.name not in skip_onnx_name and init.name in rev_init_aliases:
+            new_names = [
+                k
+                for k in rev_init_aliases[init.name]
+                if k in torch_results and k not in skip_mapping_torch_onnx
+            ]
+            if new_names and len(new_names) == 1:
+                new_name = new_names[0]  # type: ignore[assignment, index]
+                t = torch_results[new_name]
+                if (
+                    len(set(t.shape)) == len(t.shape)  # not repeated dimension
+                    and t.shape == tuple(init.dims)
+                    and torch_dtype_to_onnx_dtype(t.dtype) == init.data_type
+                ):
+                    torch_names_to_onnx_names[new_name] = init.name
+                else:
+                    t = None
+
+        # We should check tensors and proto are the same.
+        if t is None:
+            t = to_tensor(init)
+            if default_device and t.numel() >= 1024:
+                # Let's force its way to cuda (should check the device as well).
+                t = t.to(default_device)
+            records_to_yield.append(
+                RunAlignedRecord(
+                    onnx_id_node=-1,
+                    onnx_name=init.name,
+                    onnx_op_type="initializer",
+                    onnx_shape_type=string_type(t, **str_kws),
+                ).check(already_yielded)
+            )
+
+        size = t.element_size() * t.numel()
+        if t.is_cuda:
+            memory_cuda += size
+        else:
+            memory_cpu += size
+        if init.name not in onnx_results:
+            # otherwise, it is an input with a default value
+            onnx_results[init.name] = _check_tensor_(use_tensor, init.name, t, flip_type=True)
+    return mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield
 
 
 def run_aligned(
     ep: torch.export.ExportedProgram,
     onx: Union[onnx.ModelProto, onnx.FunctionProto],
-    args: Tuple[torch.Tensor, ...],
-    check_conversion_cls: Union[Dict[str, Any], type],
+    run_cls: Callable[
+        [
+            Union[
+                onnx.ModelProto,
+                onnx.FunctionProto,
+                onnx.GraphProto,
+                onnx.NodeProto,
+            ]
+        ],
+        List[Union[np.ndarray, torch.Tensor]],
+    ],
+    args: Optional[Tuple[torch.Tensor, ...]] = None,
     kwargs: Optional[Dict[str, Any]] = None,
+    use_tensor: bool = False,
+    atol: Optional[float] = None,
+    rtol: Optional[float] = None,
     verbose: int = 0,
-) -> Iterator[Tuple[Any, ...]]:
+    exc: bool = True,
+    reset_names: Optional[List[str]] = None,
+    replay_configuration: Optional[ReplayConfiguration] = None,
+    run_onnx_with_torch_inputs: bool = False,
+) -> Iterator[RunAlignedRecord]:
     """
     Runs in parallel both the exported program
     and the onnx proto and looks for discrepancies.
@@ -162,11 +590,25 @@ def run_aligned(
 
     :param ep: exported program
     :param onx: model or function proto
+    :param run_cls: defines the runtime to use for this task
     :param args: input args
-    :param check_conversion_cls: defines the runtime to use for this task
     :param kwargs: input kwargs
+    :param use_tensor: use torch tensors instead of numpy arrays
+        for the onnx runtime
+    :param atol: absolute tolerance
+    :param rtol: relative tolerance
     :param verbose: verbosity level
-    :return: a list of tuples containing the results, they come in tuple,
+    :param exc: if True, raises an exception as soon as discrepancies are detected
+    :param reset_names: list of names, the onnx execution takes the torch outputs instead
+        of its own result if the name falls into that set
+    :param replay_configuration: configuration letting the user dump any problematic
+        piece of the onnx graph they want to replay and investigate later,
+        see :class:`ReplayConfiguration
+        <onnx_diagnostic.torch_onnx.sbs.ReplayConfiguration>`
+    :param run_onnx_with_torch_inputs: run an onnx operator with torch results
+        if they are available
+    :return: a list of :class:`RunAlignedRecord
+        <onnx_diagnostic.torch_onnx.sbs_dataclasses.RunAlignedRecord>`
 
     Example:
 
@@ -174,11 +616,10 @@ def run_aligned(
         :showcode:
         :warningout: UserWarning
 
-        import pprint
         import pandas
         import torch
         from onnx_diagnostic.reference import (
-            # This can be replace by any runtime taking NodeProto as an input.
+            # This can be replaced by any runtime taking NodeProto as an input.
            ExtendedReferenceEvaluator as ReferenceEvaluator,
        )
        from onnx_diagnostic.torch_onnx.sbs import run_aligned
@@ -193,13 +634,6 @@ def run_aligned(
             return ru
 
 
-        def post_process(obs):
-            dobs = dict(zip(["ep_id_node", "onnx_id_node", "ep_name", "onnx_name"], obs))
-            dobs["err_abs"] = obs[-1]["abs"]
-            dobs["err_rel"] = obs[-1]["rel"]
-            return dobs
-
-
        x = torch.randn((5, 4))
        Model()(x)  # to make sure the model is running
        ep = torch.export.export(
@@ -209,131 +643,341 @@ def run_aligned(
            Model(), (x,), dynamic_shapes=({0: torch.export.Dim("batch")},)
        ).model_proto
        results = list(
-            map(
-                post_process,
-                run_aligned(
-                    ep,
-                    onx,
-                    (x,),
-                    check_conversion_cls=dict(cls=ReferenceEvaluator, atol=1e-5, rtol=1e-5),
-                    verbose=1,
-                ),
-            ),
+            run_aligned(ep, onx, ReferenceEvaluator, (x,), atol=1e-5, rtol=1e-5, verbose=1)
        )
        print("------------")
        print("final results")
        df = pandas.DataFrame(results)
+        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
        print(df)
-    """
-    assert not kwargs, f"Not implemented when kwargs={string_type(kwargs,with_shape=True)}"
-    cls, atol, rtol = (
-        (
-            check_conversion_cls["cls"],
-            check_conversion_cls["atol"],
-            check_conversion_cls["rtol"],
+
+    This example uses :class:`onnx.reference.ReferenceEvaluator` to run the onnx model
+    but onnxruntime can also be used through
+    :class:`onnx_diagnostic.helpers.ort_session.InferenceSessionForTorch`.
+    It relies on :epkg:`onnxruntime` and selects CPU or CUDA depending
+    on the device where the inputs are located.
+
+    The :class:`torch.export.ExportedProgram` can be saved on disk
+    with ``ep.save("<filename>.pt")`` and restored with
+    ``torch.export.load("<filename>.pt")``. That leaves only the inputs to save.
+    We can decouple the export from the alignment.
+
+    .. runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_export_patches.patch_inputs import use_dyn_not_str
+
+
+        class Model(torch.nn.Module):
+            def forward(self, x):
+                ry = x.abs()
+                rz = ry.exp()
+                rw = rz + 1
+                ru = rw.log() + rw
+                return ru
+
+
+        x = torch.randn((5, 4))
+        dynamic_shapes = ({0: "batch"},)
+        Model()(x)  # to make sure the model is running
+        ep = torch.export.export(
+            Model(), (x,), dynamic_shapes=use_dyn_not_str(dynamic_shapes)
        )
-        if isinstance(check_conversion_cls, dict)
-        else (check_conversion_cls, None, None)
-    )
+        onx = torch.onnx.export(
+            Model(), (x,), dynamic_shapes=dynamic_shapes
+        ).model_proto
 
-    # retrieve the positions
-    positions: Dict[str, Any] = {}
-    for i, node in enumerate(ep.graph.nodes):
-        if isinstance(node.name, str):
-            positions[node.name] = dict(fx=i)
-        else:
-            for n in node.name:
-                positions[n] = dict(fx=i)
+        torch.export.save(ep, "test_doc_sbs_example.pt2")
+        onnx.save(onx, "test_doc_sbs_example.onnx")
+        torch.save((x,), "test_doc_sbs_example.pt")
 
-    for i, node in enumerate(onx.graph.node):
-        for n in node.output:
-            if n in positions:
-                positions[n]["onnx"] = i
-            else:
-                positions[n] = dict(onnx=i)
+    Then we can restore all of them and run the alignment.
 
-    onnx_results: Dict[str, Any] = {}
-    for init in onx.graph.initializer:  # type: ignore
-        positions[init.name] = -1
-        onnx_results[init.name] = to_array_extended(init)
-        param_name = f"p_{init.name.replace('.', '_')}"
-        if param_name == init.name:
-            continue
-        assert param_name not in onnx_results, (
-            f"Some confusion may happen because {init.name!r} -> {param_name!r} "
-            f"and onnx_results has {sorted(onnx_results)}"
+    .. runpython::
+        :showcode:
+        :warningout: UserWarning
+
+        import pandas
+        import onnx
+        import torch
+        from onnx_diagnostic.torch_onnx.sbs import run_aligned
+        from onnx_diagnostic.reference import OnnxruntimeEvaluator
+
+
+        ep = torch.export.load("test_doc_sbs_example.pt2")
+        onx = onnx.load("test_doc_sbs_example.onnx")
+        inputs = torch.load("test_doc_sbs_example.pt")
+
+
+        results = list(
+            run_aligned(
+                ep,
+                onx,
+                OnnxruntimeEvaluator,
+                inputs,
+                atol=1e-5,
+                rtol=1e-5,
+                verbose=1,
+                use_tensor=True,
+            )
        )
-        onnx_results[param_name] = onnx_results[init.name]
+        print("------------")
+        print("final results")
+        df = pandas.DataFrame(results)
+        df = df.apply(lambda col: col.fillna("") if col.dtype == "object" else col)
+        print(df)
+
+    A command line can also be run:
+
+    .. code-block:: bash
 
-    torch_results: Dict[str, Any] = {
-        k: torch.from_numpy(v.copy())
-        for k, v in onnx_results.items()
-        if not k.startswith("init")
+        python -m onnx_diagnostic sbs -i <tensors>.input.pt \\
+            --ep <exported_program>.pt2 \\
+            -m <model>.onnx \\
+            -o results.xlsx \\
+            -v 1 --atol=0.1 --rtol=1
+    """
+    assert callable(run_cls), f"run_cls={run_cls} not a callable"
+    already_yielded = {}
+    reset_names = set(reset_names) if reset_names else set()
+    str_kws = dict(with_shape=True, with_device=True)
+    has_cuda = any(
+        (isinstance(t, torch.Tensor) and t.is_cuda)
+        for t in flatten_object([args, kwargs], drop_keys=True)
+    )
+    default_device = None
+    if has_cuda:
+        for t in flatten_object([args, kwargs], drop_keys=True):
+            if t is not None and t.is_cuda:
+                default_device = t.device
+                break
+    run_cls_kwargs = {
+        "ir_version": onx.ir_version,
+        "opsets": {d.domain: d.version for d in onx.opset_import},
+        "verbose": max(verbose - 1, 0),
+        "providers": (
+            ["CUDAExecutionProvider", "CPUExecutionProvider"]
+            if has_cuda
+            else ["CPUExecutionProvider"]
+        ),
    }
-    last_position = 0
+    run_cls_kwargs = {
+        k: v
+        for k, v in run_cls_kwargs.items()
+        if k in set(inspect.signature(run_cls).parameters)
+    }
+    if verbose:
+        print(f"[run_aligned] run_cls={run_cls}")
+        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
+    if replay_configuration:
+        print(f"[run_aligned] replay={replay_configuration}")
+
+    # preparation with ep.graph.nodes
+    ep_state_dict = {**ep.state_dict, **dict(ep.named_buffers(), **ep.tensor_constants)}
+    placeholders_to_state_dict = {
+        **{f"p_{name.replace('.', '_').lower()}": name for name in ep.state_dict},
+        **{f"b_{name.replace('.', '_').lower()}": name for name, _ in ep.named_buffers()},
+        **{f"c_{name.replace('.', '_').lower()}": name for name in ep.tensor_constants},
+    }
+    skip_mapping_torch_onnx = _duplicated_values(placeholders_to_state_dict)
+    placeholders = {}
+    if verbose:
+        print(f"[run_aligned] ep: model has {len(ep_state_dict)} torch constants or weights.")
+
+    if verbose:
+        print(f"[run_aligned] ep: walks through {len(ep.graph.nodes)} nodes from torch")
+
+    # dictionary mapping result names and their position in both graphs.
+    positions: Dict[str, Dict[str, int]] = {}
+    ep_graph_nodes = list(ep.graph.nodes)
+    torch_results: Dict[str, Any] = {}
     torch_output_names = None
-    for node in ep.graph.nodes:
-        if node.op == "output":
-            torch_output_names = [n.name for n in node.args[0]]
-    onnx_outputs_names = [o.name for o in onx.graph.output]
-    assert torch_output_names is not None and len(torch_output_names) == len(
-        onnx_outputs_names
-    ), (
-        f"Unexpected number of outputs, torch_output_names={torch_output_names}, "
-        f"onnx_outputs_names={onnx_outputs_names}"
+    torch_input_names: List[str] = []
+    name_to_ep_node = {}
+    torch_names_to_onnx_names = {}
+    torch_output_names = _preparation_with_fx_graph(
+        ep_graph_nodes,
+        name_to_ep_node,
+        torch_input_names,
+        onx,
+        torch_names_to_onnx_names,
+        skip_mapping_torch_onnx,
+        torch_results,
+        placeholders,
+        placeholders_to_state_dict,
+        ep_state_dict,
+        positions,
    )
-    mapping_onnx_to_torch = dict(zip(onnx_outputs_names, torch_output_names))
 
+    # preparation for onnx
     if verbose:
-        for k, v in torch_results.items():
-            print(
-                f"[run_aligned] +torch-cst: {k}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
-        for k, v in onnx_results.items():
-            print(
-                f"[run_aligned] +onnx-init: {k}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
+        print(f"[run_aligned] ep: found {len(torch_results)} torch constants or weights.")
+        print(f"[run_aligned] ep: found inputs {torch_input_names}")
+        print(f"[run_aligned] ep: found outputs {torch_output_names}")
+        print(f"[run_aligned] nx: walks through {len(onx.graph.node)} nodes from onnx")
 
-    for inp, v in zip(onx.graph.input, args):
-        onnx_results[inp.name] = to_numpy(v)
-        if verbose:
-            print(
-                f"[run_aligned] +onnx-input: {inp.name}: "
-                f"{string_type(v, with_shape=True, with_min_max=True)}"
-            )
+    mapping_onnx_to_torch, onnx_results, memory_cpu, memory_cuda, records_to_yield = (
+        _preparation_with_onnx_model(
+            default_device,
+            use_tensor,
+            onx,
+            already_yielded,
+            str_kws,
+            positions,
+            torch_names_to_onnx_names,
+            torch_output_names,
+            torch_results,
+            skip_mapping_torch_onnx,
+            verbose,
+            args,
+            kwargs,
+        )
+    )
+    for record in records_to_yield:
+        yield record
 
-    for i, node in enumerate(ep.graph.nodes):
-        if verbose:
+    if verbose:
+        print(f"[run_aligned] nx: handled {len(onnx_results)} initializers from onnx")
+        print(f"[run_aligned] nx: memory cpu {memory_cpu / 2**20:.3f} Mb")
+        print(f"[run_aligned] nx: memory cuda {memory_cuda / 2**20:.3f} Mb")
+        print(f"[run_aligned] nx: {len(onnx_results)} constants")
+        print(f"[run_aligned] nx: {len(onx.graph.input)} inputs")
+        print(f"[run_aligned] nx: {len(onx.graph.output)} outputs")
+        print(f"[run_aligned] bo: {len(mapping_onnx_to_torch)} outputs")
+        print(f"[run_aligned] run_cls_kwargs={run_cls_kwargs}")
+    if verbose > 1:
+        for k, v in torch_results.items():
+            print(f"[run_aligned-ep] +cst: {k}: {string_type(v, **str_kws)}")
+        for k, v in onnx_results.items():
+            print(f"[run_aligned-nx] +ini: {k}: {string_type(v, **str_kws)}")
+
+    # starts the side-by-side
+    if verbose:
+        print(
+            f"[run_aligned] ep: starts side-by-side with {len(ep_graph_nodes)} "
+            f"fx nodes and {len(onx.graph.node)} onnx nodes"
+        )
+    if verbose == 1:
+        import tqdm
+
+        loop = tqdm.tqdm(total=len(ep_graph_nodes) + len(onx.graph.node))
+    else:
+        loop = None
+
+    already_run: Set[int] = set()
+    ep_durations = {}
+    status = StatusRunAligned()
+    last_position = 0
+    for i_torch, node in enumerate(ep_graph_nodes):
+        if verbose > 1:
            if node.op == "call_function":
                print(
-                    f"[run_aligned] run ep.graph.nodes[{i}]: "
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: "
                    f"{node.op}[{node.target}] -> {node.name!r}"
                )
            else:
-                print(f"[run_aligned] run ep.graph.nodes[{i}]: {node.op} -> {node.name!r}")
+                print(
+                    f"[run_aligned] run ep.graph.nodes[{i_torch}]: {node.op} -> {node.name!r}"
+                )
+        elif verbose == 1:
+            loop.set_description(
+                f"ep {i_torch}/{len(ep_graph_nodes)} nx {last_position}/{len(onx.graph.node)} "
+                f"{status.to_str()}"
+            )
+            loop.update(min(1, 1 + i_torch + last_position))
 
        if node.op == "placeholder":
-            if node.name in onnx_results:
-                torch_results[node.name] = torch.from_numpy(onnx_results[node.name].copy())
-                if verbose:
-                    t = torch_results[node.name]
-                    print(
-                        f"[run_aligned] +torch {node.name}="
-                        f"{string_type(t, with_shape=True, with_min_max=True)}"
+            is_input = node.name not in placeholders
+            if is_input:
+                torch_results[node.name] = (
+                    onnx_results[torch_names_to_onnx_names[node.name]]
+                    if use_tensor
+                    else from_numpy(onnx_results[torch_names_to_onnx_names[node.name]])
+                )
+                assert isinstance(torch_results[node.name], torch.Tensor), (
+                    f"torch_results[{node.name}] not a tensor but "
+                    f"{type(torch_results[node.name])}, use_tensor={use_tensor}"
+                )
+                t = torch_results[node.name]
+                if verbose > 1:
+                    print(f"[run_aligned-ep] =ags: {node.name}={string_type(t, **str_kws)}")
+                # Otherwise, it is an input.
+                record = RunAlignedRecord(
+                    ep_id_node=i_torch,
+                    onnx_id_node=-1,
+                    ep_name=node.name,
+                    onnx_name=torch_names_to_onnx_names[node.name],
+                    ep_target="input",
+                    onnx_op_type="input",
+                    ep_shape_type=string_type(t, **str_kws),
+                    onnx_shape_type=string_type(
+                        onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
+                    ),
+                )
+                yield record.check(already_yielded)
+            else:
+                assert node.name in placeholders_to_state_dict, (
+                    f"Unable to find placeholder {node.name!r} (node.op={node.op!r}), "
+                    f"existing: {sorted(placeholders_to_state_dict)}"
+                )
+                assert node.name in torch_results, (
+                    f"placeholder {node.name!r} (node.op={node.op!r}), "
+                    f"should have been added to torch_results: {sorted(torch_results)}"
+                )
+                t = torch_results[node.name]
+                if (
+                    node.name in torch_names_to_onnx_names
+                    and node.name not in skip_mapping_torch_onnx
+                ):
+                    if verbose > 1:
+                        print(
+                            f"[run_aligned-ep] =plh: "
+                            f"{node.name}={string_type(t, **str_kws)}"
+                        )
+                    record = RunAlignedRecord(
+                        ep_id_node=i_torch,
+                        onnx_id_node=-1,
+                        ep_name=node.name,
+                        onnx_name=torch_names_to_onnx_names[node.name],
+                        ep_target="placeholder",
+                        onnx_op_type="initializer",
+                        ep_shape_type=string_type(t, **str_kws),
+                        onnx_shape_type=string_type(
+                            onnx_results[torch_names_to_onnx_names[node.name]], **str_kws
+                        ),
                    )
-                continue
-            raise AssertionError(
-                f"unable to process node {node.op} -> {node.name!r} "
-                f"not in {sorted(onnx_results)}, len(args)={len(args)}, "
-                f"onx.graph.input={[i.name for i in onx.graph.input]}"
-            )
+                    if not is_input:
+                        record.set_diff(
+                            max_diff(
+                                t,
+                                onnx_results[torch_names_to_onnx_names[node.name]],
+                                hist=[0.1, 0.01],
+                            )
+                        )
+                    yield record.check(already_yielded)
+                else:
+                    if verbose > 1:
+                        print(
+                            f"[run_aligned-ep] +plh: {node.name}={string_type(t, **str_kws)}"
+                        )
+                    yield RunAlignedRecord(
+                        ep_id_node=i_torch,
+                        ep_name=node.name,
+                        ep_target="placeholder",
+                        ep_shape_type=string_type(t, **str_kws),
+                    ).check(already_yielded)
+            continue
 
        outputs = [node.name] if isinstance(node.name, str) else list(node.name)
        args, kwargs = prepare_args_kwargs(torch_results, node)
+        begin = time.perf_counter()
        new_outputs = run_fx_node(node, args, kwargs)
-        if isinstance(new_outputs, (torch.Tensor, int, float, list)):
+        duration = time.perf_counter() - begin
+        ep_durations[i_torch] = duration
+        if isinstance(new_outputs, (torch.Tensor, int, float, list, tuple)):
            new_outputs = (new_outputs,)
 
        if new_outputs is None:
@@ -342,99 +986,112 @@ def run_aligned(
 
        for k, v in zip(outputs, new_outputs):
            torch_results[k] = v
-        if verbose:
+        if verbose > 1:
            for k, v in zip(outputs, new_outputs):
-                print(
-                    f"[run_aligned] +torch {k}="
-                    f"{string_type(v, with_shape=True, with_min_max=True)}"
-                )
+                print(f"[run_aligned-ep] +res: {k}={string_type(v, **str_kws)}")
 
        max_pos = -2
        for n in outputs:
-            if n in positions and "onnx" in positions[n]:
-                max_pos = max(max_pos, positions[n]["onnx"])
+            if n in positions:
+                if "onnx" in positions[n]:
+                    max_pos = max(max_pos, positions[n]["onnx"])
+                if "fx" in positions[n]:
+                    if positions[n]["fx"] > i_torch:
+                        max_pos = -2
+                        break
        if max_pos == -2:
            # we skip.
            continue
 
+        next_to_visit = last_position
        for i_onnx in range(last_position, max_pos + 1):
+            if i_onnx in already_run:
+                continue
+            # The onnx node may produce more than one output, in that
+            # case, we need to check the exported program is not behind.
            node = onx.graph.node[i_onnx]
-            if verbose:
-                print(
-                    f"[run_aligned] run onx.graph.node[{i_onnx}]: "
-                    f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
-                )
-            ref = cls(node)
-            feeds = {k: onnx_results[k] for k in node.input}
-            res = ref.run(None, feeds)
-            for o, r in zip(node.output, res):
-                onnx_results[o] = r
-                if verbose:
-                    print(
-                        f"[run_aligned] +onnx {o}="
-                        f"{string_type(r, with_shape=True, with_min_max=True)}"
-                    )
+            ep_behind = False
+            for iname in node.output:
+                if iname in positions and "fx" in positions[iname]:
+                    if positions[iname]["fx"] > i_torch:
+                        ep_behind = True
+                        break
+            if ep_behind:
+                break
 
-                to = mapping_onnx_to_torch.get(o, o)
-                if to in torch_results:
-                    d = max_diff(torch_results[to], r)
-                    if verbose:
-                        if o == to:
-                            print(f"[run_aligned] =common results {to}: {string_diff(d)}")
-                        else:
-                            print(f"[run_aligned] =common results {to}/{o}: {string_diff(d)}")
-                    if not (
-                        atol is None
-                        or rtol is None
-                        or (d["abs"] <= atol and d["rel"] <= rtol)
-                    ):
-                        skw = dict(with_shape=True, with_min_max=True)
-                        raise ValueError(
-                            f"discrepancies detected for results [{to}/{o}]: "
-                            f"{string_diff(d)}"
-                            f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
-                            f"\n-- onnx_results: {string_type(r, **skw)}"
-                            f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
-                        )
-                    yield (i, i_onnx, o, to, d)
+            for r in _loop_onnx_node(
+                onx,
+                ep_graph_nodes,
+                onnx_results,
+                mapping_onnx_to_torch,
+                torch_results,
+                ep_durations,
+                use_tensor,
+                i_torch,
+                i_onnx,
+                name_to_ep_node,
+                run_cls_kwargs,
+                str_kws,
+                status,
+                already_run,
+                torch_names_to_onnx_names,
+                verbose,
+                exc,
+                atol,
+                rtol,
+                reset_names,
+                replay_configuration,
+                has_cuda,
+                run_cls,
+                loop,
+                run_onnx_with_torch_inputs,
+            ):
+                if r:
+                    yield r.check(already_yielded)
+            next_to_visit = i_onnx + 1
 
-        last_position = max_pos + 1
+        last_position = next_to_visit
 
    # complete the execution of the onnx graph
+    if verbose > 1:
+        print(
+            f"[run_aligned] complete execution of onnx graph from pos={last_position} "
+            f"to {len(onx.graph.node)}"
+        )
    for i_onnx in range(last_position, len(onx.graph.node)):
-        node = onx.graph.node[i_onnx]
-        if verbose:
-            print(
-                f"[run_aligned] run onx.graph.node[{i_onnx}]: "
-                f"{node.op_type}({', '.join(node.input)}) -> {', '.join(node.output)}"
-            )
-        ref = cls(node)
-        feeds = {k: onnx_results[k] for k in node.input}
-        res = ref.run(None, feeds)
-        for o, r in zip(node.output, res):
-            onnx_results[o] = r
-            if verbose:
-                print(
-                    f"[run_aligned] +onnx {o}="
-                    f"{string_type(r, with_shape=True, with_min_max=True)}"
-                )
+        if i_onnx in already_run:
+            continue
+        for r in _loop_onnx_node(
+            onx,
+            ep_graph_nodes,
+            onnx_results,
+            mapping_onnx_to_torch,
+            torch_results,
+            ep_durations,
+            use_tensor,
+            i_torch,
+            i_onnx,
+            name_to_ep_node,
+            run_cls_kwargs,
+            str_kws,
+            status,
+            already_run,
+            torch_names_to_onnx_names,
+            verbose,
+            exc,
+            atol,
+            rtol,
+            reset_names,
+            replay_configuration,
+            has_cuda,
+            run_cls,
+            loop,
+            run_onnx_with_torch_inputs,
+        ):
+            if r:
+                yield r.check(already_yielded)
 
-            to = mapping_onnx_to_torch.get(o, o)
-            if to in torch_results:
-                d = max_diff(torch_results[to], r)
-                if verbose:
-                    if o == to:
-                        print(f"[run_aligned] =common results* {to}: {string_diff(d)}")
-                    else:
-                        print(f"[run_aligned] =common results* {to}/{o}: {string_diff(d)}")
-                if not (
-                    atol is None or rtol is None or (d["abs"] <= atol and d["rel"] <= rtol)
-                ):
-                    skw = dict(with_shape=True, with_min_max=True)
-                    raise ValueError(
-                        f"discrepancies detected for results* [{to}/{o}]: {string_diff(d)}"
-                        f"\n-- torch_results: {string_type(torch_results[to], **skw)}"
-                        f"\n-- onnx_results: {string_type(r, **skw)}"
-                        f"\n-- torch\n{torch_results[to]}\n-- onnx\n{r}"
-                    )
-                yield (i, i_onnx, o, to, d)
+    if loop is not None:
+        loop.close()
+    if verbose:
+        print(f"[run_aligned] done with status={status.to_str()}")