onnx-diagnostic 0.6.3__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. onnx_diagnostic/__init__.py +1 -1
  2. onnx_diagnostic/_command_lines_parser.py +87 -77
  3. onnx_diagnostic/doc.py +22 -0
  4. onnx_diagnostic/ext_test_case.py +1 -1
  5. onnx_diagnostic/helpers/cache_helper.py +59 -0
  6. onnx_diagnostic/helpers/config_helper.py +8 -4
  7. onnx_diagnostic/helpers/helper.py +30 -3
  8. onnx_diagnostic/helpers/log_helper.py +585 -0
  9. onnx_diagnostic/helpers/mini_onnx_builder.py +4 -1
  10. onnx_diagnostic/helpers/model_builder_helper.py +54 -73
  11. onnx_diagnostic/helpers/torch_helper.py +18 -2
  12. onnx_diagnostic/reference/__init__.py +1 -0
  13. onnx_diagnostic/reference/ort_evaluator.py +29 -4
  14. onnx_diagnostic/reference/report_results_comparison.py +95 -0
  15. onnx_diagnostic/reference/torch_evaluator.py +21 -0
  16. onnx_diagnostic/tasks/automatic_speech_recognition.py +3 -0
  17. onnx_diagnostic/tasks/feature_extraction.py +3 -0
  18. onnx_diagnostic/tasks/fill_mask.py +3 -0
  19. onnx_diagnostic/tasks/image_classification.py +7 -1
  20. onnx_diagnostic/tasks/image_text_to_text.py +3 -0
  21. onnx_diagnostic/tasks/mixture_of_expert.py +3 -0
  22. onnx_diagnostic/tasks/object_detection.py +3 -0
  23. onnx_diagnostic/tasks/sentence_similarity.py +3 -0
  24. onnx_diagnostic/tasks/summarization.py +3 -0
  25. onnx_diagnostic/tasks/text2text_generation.py +3 -0
  26. onnx_diagnostic/tasks/text_classification.py +3 -0
  27. onnx_diagnostic/tasks/text_generation.py +90 -43
  28. onnx_diagnostic/tasks/zero_shot_image_classification.py +3 -0
  29. onnx_diagnostic/torch_export_patches/onnx_export_errors.py +78 -25
  30. onnx_diagnostic/torch_export_patches/onnx_export_serialization.py +37 -0
  31. onnx_diagnostic/torch_export_patches/patches/patch_transformers.py +365 -17
  32. onnx_diagnostic/torch_models/hghub/hub_api.py +20 -4
  33. onnx_diagnostic/torch_models/hghub/hub_data_cached_configs.py +209 -0
  34. onnx_diagnostic/torch_models/hghub/model_inputs.py +3 -0
  35. onnx_diagnostic/torch_models/untrained/llm_tiny_llm.py +23 -50
  36. onnx_diagnostic/torch_models/{test_helper.py → validate.py} +158 -103
  37. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/METADATA +2 -2
  38. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/RECORD +41 -39
  39. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/WHEEL +0 -0
  40. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/licenses/LICENSE.txt +0 -0
  41. {onnx_diagnostic-0.6.3.dist-info → onnx_diagnostic-0.7.0.dist-info}/top_level.txt +0 -0
onnx_diagnostic/helpers/model_builder_helper.py
@@ -3,9 +3,9 @@ import os
 import requests
 import sys
 from pathlib import Path
-from typing import Any, Optional
+from typing import Any, Optional, Union
 from urllib.parse import urlparse
-from onnx import helper, save_model, external_data_helper, ModelProto
+from onnx import ModelProto, TensorProto

 CACHE_SUBDIR = "onnx-diagnostic"

onnx_diagnostic/helpers/model_builder_helper.py
@@ -114,87 +114,58 @@ def _make_model(self, model, verbose: int = 0):
         self.make_lm_head(module)


-def save_model_builder(self, out_dir: Optional[str] = "", verbose: int = 0) -> ModelProto:
+def save_model_builder(
+    self, out_dir: Optional[str] = "", verbose: int = 0
+) -> Union[str, ModelProto]:
     """
     Saves a model created by function :func:`create_model_builder`.
     If out_dir is empty or not specified, the function still returns the
     generated model.
     """
-    if verbose:
-        print(f"[save_model_builder] Saving ONNX model in {out_dir}")
-
-    # Create ONNX model
-    model = helper.make_model(
-        opset_imports=[
-            self.clear_field(
-                helper.make_operatorsetid("", 21 if self.quant_attrs["use_qdq"] else 14),
-                "domain",
-            ),
-            helper.make_operatorsetid("com.microsoft", 1),
-        ],
-        ir_version=7,
-        producer_name="onnxruntime-genai",
-        producer_version="0.0.0",
-        graph=self.make_graph(
-            name="main_graph",
-            inputs=self.inputs,
-            outputs=self.outputs,
-            initializer=self.initializers,
-            value_info=self.value_infos,
-            nodes=self.nodes,
-        ),
-    )
-
-    # Load external data into ONNX model
-    external_data_helper.load_external_data_for_model(model, self.cache_dir)
-
-    # Delete external data files on disk before re-saving
-    for path in os.listdir(self.cache_dir):
-        if path.endswith(".bin"):
-            os.remove(os.path.join(self.cache_dir, path))
+    import onnx_ir

-    # Delete temporary cache dir if empty
-    # if len(os.listdir(self.cache_dir)) == 0:
-    #     os.rmdir(self.cache_dir)
+    if verbose:
+        print(f"[save_model_builder] Saving ONNX model in {out_dir!r}")

-    # Quantize ONNX model to desired precision
+    # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
     already_quantized_in_qdq_format = (
         self.quant_type is not None and self.quant_attrs["use_qdq"]
-    )  # Skip quantizing `MatMul` in `DequantizeLinear --> Transpose --> MatMul` path
-    if self.onnx_dtype == "int4" and not already_quantized_in_qdq_format:
-        model = self.to_int4(model)
+    )
+    model = (
+        self.to_int4()
+        if self.onnx_dtype in {onnx_ir.DataType.INT4, onnx_ir.DataType.UINT4}
+        and not already_quantized_in_qdq_format
+        else self.model
+    )
+    model.graph.sort()
+    if not out_dir:
+        return onnx_ir.to_proto(model)

-    # Save ONNX model with only one external data file and delete any existing duplicate copies
-    if out_dir:
-        out_path = os.path.join(out_dir, self.filename)
-        data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
-        if os.path.exists(out_path):
-            if verbose:
-                print(f"[save_model_builder] Overwriting {out_path!r}")
-            os.remove(out_path)
-        if os.path.exists(data_path):
-            if verbose:
-                print(f"[save_model_builder] Overwriting {data_path!r}")
-            os.remove(data_path)
+    out_path = os.path.join(out_dir, self.filename)
+    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")

-    if out_dir:
-        location = os.path.basename(data_path)
-        if os.path.exists(location):
-            os.remove(location)
+    # Save ONNX model with only one external data file and delete any existing duplicate copies
+    out_path = os.path.join(out_dir, self.filename)
+    data_path = os.path.join(out_dir, os.path.basename(out_path) + ".data")
+    if os.path.exists(out_path):
         if verbose:
-            print(f"[save_model_builder] out_path={out_path!r}")
-            print(f"[save_model_builder] location={location!r}")
-        save_model(
-            model,
-            out_path,
-            save_as_external_data=True,
-            all_tensors_to_one_file=True,
-            location=location,
-            size_threshold=1024,
-            convert_attribute=False,
-        )
-        return None
-    return model
+            print(f"[save_model_builder] Overwriting {out_path!r}")
+        os.remove(out_path)
+    if os.path.exists(data_path):
+        if verbose:
+            print(f"[save_model_builder] Overwriting {data_path!r}")
+        os.remove(data_path)
+
+    onnx_ir.save(
+        model,
+        out_path,
+        external_data=os.path.basename(data_path),
+        size_threshold_bytes=2**10,
+    )
+    if verbose:
+        print(f"[save_model_builder] saved in {out_dir!r}")
+
+    return out_path


 def create_model_builder(
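Note: the saving logic above now delegates to onnx_ir instead of onnx.helper/save_model. A minimal standalone sketch of the new behaviour, using only the onnx_ir calls that appear in the hunk (the model and file names are placeholders):

    import os
    import onnx_ir

    def save_with_one_data_file(model: "onnx_ir.Model", out_dir: str, filename: str) -> str:
        # writes <out_dir>/<filename> plus a single <filename>.data file next to it
        out_path = os.path.join(out_dir, filename)
        data_path = out_path + ".data"
        for p in (out_path, data_path):
            if os.path.exists(p):
                os.remove(p)  # stale copies would otherwise shadow the new weights
        onnx_ir.save(
            model,
            out_path,
            external_data=os.path.basename(data_path),  # one file for all large tensors
            size_threshold_bytes=2**10,                 # tensors above 1 KiB go external
        )
        return out_path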
onnx_diagnostic/helpers/model_builder_helper.py
@@ -335,13 +306,23 @@ def create_model_builder(
     for c in remove:
         delattr(config, c)

-    onnx_model = cls(config, io_dtype, precision, execution_provider, cache_dir, extra_options)
+    convert = {
+        "fp32": TensorProto.FLOAT,
+        "fp16": TensorProto.FLOAT16,
+        "bfp16": TensorProto.BFLOAT16,
+    }
+    assert (
+        precision in convert
+    ), f"Unexpected value for precision={precision!r}, should be in {convert}"
+    onnx_model = cls(
+        config, io_dtype, convert[precision], execution_provider, cache_dir, extra_options
+    )

     if post:
         post(onnx_model)
     _make_model(onnx_model, model, verbose=verbose)

-    assert onnx_model.nodes, (
+    assert onnx_model.model, (
         f"No node in the model, io_dtype={io_dtype!r}, "
         f"precision={precision!r}, execution_provider={execution_provider!r}, "
         f"extra_options={extra_options!r}, cache_dir={cache_dir!r}, "
onnx_diagnostic/helpers/torch_helper.py
@@ -16,6 +16,7 @@ from .cache_helper import (
     make_encoder_decoder_cache,
     make_sliding_window_cache,
     make_mamba_cache,
+    make_static_cache,
 )
 from .mini_onnx_builder import create_onnx_model_from_input_tensors
 from .onnx_helper import (
@@ -288,7 +289,8 @@
     """
     The necessary modification to steal a forward method and print out inputs
     and outputs using :func:`onnx_diagnostic.helpers.string_type`.
-    See example :ref:`l-plot-tiny-llm-export`.
+    See example :ref:`l-plot-tiny-llm-export` or
+    :ref:`l-plot-intermediate-results`.

     :param model: a model or a list of models to monitor,
         every model can also be a tuple(name, model); the name is displayed as well.
onnx_diagnostic/helpers/torch_helper.py
@@ -410,12 +412,15 @@ def steal_forward(
         proto = create_onnx_model_from_input_tensors(storage)
         if verbose:
             print("-- dumps stored objects")
+        location = f"{os.path.split(dump_file)[-1]}.data"
+        if os.path.exists(location):
+            os.remove(location)
         onnx.save(
             proto,
             dump_file,
             save_as_external_data=True,
             all_tensors_to_one_file=True,
-            location=f"{os.path.split(dump_file)[-1]}.data",
+            location=location,
         )
         if verbose:
             print("-- done dump stored objects")
onnx_diagnostic/helpers/torch_helper.py
@@ -723,6 +728,15 @@ def to_any(value: Any, to_value: Union[torch.dtype, torch.device, str]) -> Any:
                 )
             )
         )
+    if value.__class__.__name__ == "StaticCache":
+        return make_static_cache(
+            list(
+                zip(
+                    [t.to(to_value) for t in value.key_cache],
+                    [t.to(to_value) for t in value.value_cache],
+                )
+            )
+        )
     if value.__class__.__name__ == "EncoderDecoderCache":
         return make_encoder_decoder_cache(
             to_any(value.self_attention_cache, to_value),
@@ -769,6 +783,8 @@ def torch_deepcopy(value: Any) -> Any:
         return make_dynamic_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
         )
+    if value.__class__.__name__ == "StaticCache":
+        return make_static_cache(torch_deepcopy(list(zip(value.key_cache, value.value_cache))))
     if value.__class__.__name__ == "SlidingWindowCache":
         return make_sliding_window_cache(
             torch_deepcopy(list(zip(value.key_cache, value.value_cache)))
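Both helpers rebuild a transformers StaticCache from (key, value) tensor pairs through make_static_cache, dispatching on the class name so the cache classes never need to be imported eagerly. A hedged usage sketch, assuming make_static_cache accepts the list-of-pairs form used above (shapes are placeholders):

    import torch
    from onnx_diagnostic.helpers.cache_helper import make_static_cache

    # one (key, value) pair per layer: [batch, heads, cache_len, head_dim]
    pairs = [(torch.randn(1, 2, 8, 4), torch.randn(1, 2, 8, 4)) for _ in range(3)]
    cache = make_static_cache(pairs)

    # what to_any(cache, "cpu") does tensor by tensor before rebuilding the cache
    moved = list(
        zip(
            [t.to("cpu") for t in cache.key_cache],
            [t.to("cpu") for t in cache.value_cache],
        )
    )
    cache_cpu = make_static_cache(moved)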
onnx_diagnostic/reference/__init__.py
@@ -1,3 +1,4 @@
 from .evaluator import ExtendedReferenceEvaluator
 from .ort_evaluator import OnnxruntimeEvaluator
 from .torch_evaluator import TorchOnnxEvaluator
+from .report_results_comparison import ReportResultComparison

onnx_diagnostic/reference/ort_evaluator.py
@@ -22,8 +22,11 @@ from ..helpers.ort_session import (
     InferenceSessionForNumpy,
     _InferenceSession,
 )
+from ..helpers.torch_helper import to_tensor
+from .report_results_comparison import ReportResultComparison
 from .evaluator import ExtendedReferenceEvaluator

+
 PROTO = (FunctionProto, ModelProto, GraphProto, NodeProto)
 Proto = Union[FunctionProto, ModelProto, GraphProto, NodeProto]

onnx_diagnostic/reference/ort_evaluator.py
@@ -49,6 +52,8 @@ class OnnxruntimeEvaluator:
    :param ir_version: ir version to use when unknown
    :param opsets: opsets to use when unknown
    :param whole: if True, do not split node by node
+   :param torch_or_numpy: force the use of one of them, True for torch,
+       False for numpy, None to let the class choose
    """

    def __init__(
@@ -71,6 +76,7 @@ class OnnxruntimeEvaluator:
        ir_version: int = 10,
        opsets: Optional[Union[int, Dict[str, int]]] = None,
        whole: bool = False,
+       torch_or_numpy: Optional[bool] = None,
    ):
        if isinstance(proto, str):
            self.proto: Proto = load(proto)
@@ -102,8 +108,10 @@ class OnnxruntimeEvaluator:
            disable_aot_function_inlining=disable_aot_function_inlining,
            use_training_api=use_training_api,
        )
+       self.to_tensor_or_array = to_array_extended if not torch_or_numpy else to_tensor

        self.verbose = verbose
+       self.torch_or_numpy = torch_or_numpy
        self.sess_: Optional[_InferenceSession] = None
        if whole:
            self.nodes: Optional[List[NodeProto]] = None
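The new torch_or_numpy flag is threaded through every sub-session created by the evaluator (see the hunks below) so a single choice applies to initializers and intermediate results alike. A usage sketch ("model.onnx" is a placeholder; the constructor accepts a path per the loading branch above):

    from onnx_diagnostic.reference import OnnxruntimeEvaluator

    # None (default): the evaluator picks based on the inputs it receives
    # True: keep everything as torch tensors; False: keep everything as numpy arrays
    ev = OnnxruntimeEvaluator("model.onnx", torch_or_numpy=True)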
onnx_diagnostic/reference/ort_evaluator.py
@@ -122,7 +130,10 @@ class OnnxruntimeEvaluator:
                )
            )
        self.rt_inits_ = (
-           {init.name: to_array_extended(init) for init in self.proto.graph.initializer}
+           {
+               init.name: self.to_tensor_or_array(init)
+               for init in self.proto.graph.initializer
+           }
            if hasattr(self.proto, "graph")
            else {}
        )
@@ -190,13 +201,14 @@ class OnnxruntimeEvaluator:
            return a
        device = f"D{a.get_device()}:" if hasattr(a, "detach") else ""
        if hasattr(a, "shape"):
+           prefix = "A:" if hasattr(a, "astype") else "T:"
            if self.verbose < 4:  # noqa: PLR2004
-               return f"{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
+               return f"{prefix}{device}{a.dtype}:{a.shape} in [{a.min()}, {a.max()}]"
            elements = a.ravel().tolist()
            if len(elements) > 10:  # noqa: PLR2004
                elements = elements[:10]
-               return f"{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
-           return f"{device}{a.dtype}:{a.shape}:{elements}"
+               return f"{prefix}{device}{a.dtype}:{a.shape}:{','.join(map(str, elements))}..."
+           return f"{prefix}{device}{a.dtype}:{a.shape}:{elements}"
        if hasattr(a, "append"):
            return ", ".join(map(self._log_arg, a))
        return a
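The new log prefix tells arrays and tensors apart by duck typing: numpy arrays have an astype method, torch tensors do not. A two-line check of that assumption:

    import numpy as np
    import torch

    assert hasattr(np.zeros(2), "astype")          # logged as "A:..."
    assert not hasattr(torch.zeros(2), "astype")   # logged as "T:..."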
onnx_diagnostic/reference/ort_evaluator.py
@@ -214,6 +226,7 @@ class OnnxruntimeEvaluator:
        outputs: Optional[List[str]],
        feed_inputs: Dict[str, Any],
        intermediate: bool = False,
+       report_cmp: Optional[ReportResultComparison] = None,
    ) -> Union[Dict[str, Any], List[Any]]:
        """
        Runs the model.
@@ -222,6 +235,10 @@
        :param outputs: required outputs or None for all
        :param feed_inputs: inputs
        :param intermediate: returns all outputs instead of the last ones
+       :param report_cmp: used as a reference,
+           every intermediate result is compared to every existing one;
+           if not empty, it is an instance of
+           :class:`onnx_diagnostic.reference.ReportResultComparison`
        :return: outputs, as a list if return_all is False,
            as a dictionary if return_all is True
        """
@@ -267,6 +284,10 @@
                self._log(2, " + %s: %s", name, value)  # type: ignore[arg-type]
                assert isinstance(name, str), f"unexpected type for name {type(name)}"
                results[name] = value
+           if report_cmp:
+               reported = report_cmp.report(dict(zip(node.output, outputs)))
+               if self.verbose > 1:
+                   print(f" -- report {len(reported)} comparisons")
            if not intermediate:
                self._clean_unused_inplace(i_node, node, results)

@@ -426,6 +447,7 @@
        cls = (
            InferenceSessionForNumpy
            if any(isinstance(i, np.ndarray) for i in inputs)
+           and (not isinstance(self.torch_or_numpy, bool) or not self.torch_or_numpy)
            else InferenceSessionForTorch
        )
        try:
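The session choice above compresses to a small predicate; restating it standalone makes the three cases easier to read (a sketch, not library code):

    def uses_numpy_session(has_numpy_input: bool, torch_or_numpy) -> bool:
        # mirrors the condition selecting InferenceSessionForNumpy above
        return has_numpy_input and (
            not isinstance(torch_or_numpy, bool) or not torch_or_numpy
        )

    assert uses_numpy_session(True, None) is True    # unset: follow the inputs
    assert uses_numpy_session(True, False) is True   # numpy forced
    assert uses_numpy_session(True, True) is False   # torch forced
    assert uses_numpy_session(False, None) is False  # torch inputs stay torch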
onnx_diagnostic/reference/ort_evaluator.py
@@ -486,6 +508,7 @@
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
+           torch_or_numpy=self.torch_or_numpy,
            **self.session_kwargs,
        )
        return onx, sess
@@ -500,6 +523,7 @@
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
+           torch_or_numpy=self.torch_or_numpy,
            **self.session_kwargs,
        )
        return ev.proto, sess
@@ -575,6 +599,7 @@
            verbose=self.verbose,
            ir_version=self.ir_version,
            opsets=self.opsets,
+           torch_or_numpy=self.torch_or_numpy,
            whole=True,
            **self.session_kwargs,
        )
onnx_diagnostic/reference/report_results_comparison.py (new file)
@@ -0,0 +1,95 @@
+from typing import Any, Dict, List, Tuple, Union
+
+
+ReportKeyNameType = Union[str, Tuple[str, int, str]]
+ReportKeyValueType = Tuple[int, Tuple[int, ...]]
+
+
+class ReportResultComparison:
+    """
+    Holds tensors a runtime can use as a reference to compare
+    intermediate results.
+    See :meth:`onnx_diagnostic.reference.TorchOnnxEvaluator.run`.
+
+    :param tensors: tensors
+    """
+
+    def __init__(self, tensors: Dict[ReportKeyNameType, "torch.Tensor"]):  # noqa: F821
+        from ..helpers.onnx_helper import dtype_to_tensor_dtype
+        from ..helpers import max_diff, string_type
+
+        assert all(
+            hasattr(v, "shape") and hasattr(v, "dtype") for v in tensors.values()
+        ), f"One of the tensors is not: {string_type(tensors, with_shape=True)}"
+        self.dtype_to_tensor_dtype = dtype_to_tensor_dtype
+        self.max_diff = max_diff
+        self.tensors = tensors
+        self._build_mapping()
+
+    def key(self, tensor: "torch.Tensor") -> ReportKeyValueType:  # noqa: F821
+        "Returns a key for a tensor, (onnx dtype, shape)."
+        return self.dtype_to_tensor_dtype(tensor.dtype), tuple(map(int, tensor.shape))
+
+    def _build_mapping(self):
+        mapping = {}
+        for k, v in self.tensors.items():
+            key = self.key(v)
+            if key not in mapping:
+                mapping[key] = []
+            mapping[key].append(k)
+        self.mapping = mapping
+        self.clear()
+
+    def clear(self):
+        """Clears the last report."""
+        self.report_cmp = {}
+        self.unique_run_names = set()
+
+    @property
+    def value(
+        self,
+    ) -> Dict[Tuple[Tuple[int, str], ReportKeyNameType], Dict[str, Union[float, str]]]:
+        "Returns the report."
+        return self.report_cmp
+
+    @property
+    def data(self) -> List[Dict[str, Any]]:
+        "Returns data which can be consumed by a dataframe."
+        rows = []
+        for k, v in self.value.items():
+            (i_run, run_name), ref_name = k
+            d = dict(run_index=i_run, run_name=run_name, ref_name=ref_name)
+            d.update(v)
+            rows.append(d)
+        return rows
+
+    def report(
+        self, outputs: Dict[str, "torch.Tensor"]  # noqa: F821
+    ) -> List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]]:
+        """
+        For every tensor in outputs, compares it to every tensor held by
+        this class if it shares the same type and shape. The function returns
+        the results of the comparison. The function also collects the results
+        into a dictionary the user can retrieve later.
+        """
+        res: List[Tuple[Tuple[int, str], ReportKeyNameType, Dict[str, Union[float, str]]]] = []
+        for name, tensor in outputs.items():
+            i_run = len(self.unique_run_names)
+            self.unique_run_names.add(name)
+            key = self.key(tensor)
+            if key not in self.mapping:
+                continue
+            cache: Dict["torch.device", "torch.Tensor"] = {}  # noqa: F821, UP037
+            for held_key in self.mapping[key]:
+                t2 = self.tensors[held_key]
+                if hasattr(t2, "device") and hasattr(tensor, "device"):
+                    if t2.device in cache:
+                        t = cache[t2.device]
+                    else:
+                        cache[t2.device] = t = tensor.to(t2.device)
+                    diff = self.max_diff(t, t2)
+                else:
+                    diff = self.max_diff(tensor, t2)
+                res.append((i_run, name, held_key, diff))  # type: ignore[arg-type]
+                self.report_cmp[(i_run, name), held_key] = diff
+        return res
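A hedged usage sketch of the new class: reference tensors are indexed by (onnx dtype, shape), so any intermediate result with the same key is compared to them and the discrepancies are collected (names and shapes are placeholders):

    import torch
    from onnx_diagnostic.reference import ReportResultComparison

    reference = ReportResultComparison({"expected.layer0": torch.randn(2, 4)})
    reported = reference.report({"node_7_out": torch.randn(2, 4)})  # same dtype/shape: compared
    rows = reference.data  # one dict per comparison, e.g. pandas.DataFrame(rows)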
onnx_diagnostic/reference/torch_evaluator.py
@@ -5,6 +5,7 @@ import onnx
 import torch
 from ..helpers.torch_helper import to_tensor
 from ..torch_onnx.runtime_info import first_used_last_used, RuntimeValue
+from .report_results_comparison import ReportResultComparison
 from . import torch_ops


@@ -455,12 +456,17 @@ class TorchOnnxEvaluator:
        self,
        outputs: Optional[List[str]],
        feeds: Union[Dict[str, torch.Tensor], Dict[str, np.ndarray]],
+       report_cmp: Optional[ReportResultComparison] = None,
    ) -> Union[List[Optional[torch.Tensor]], List[Optional[np.ndarray]]]:
        """
        Runs the ONNX model.

        :param outputs: outputs required
        :param feeds: inputs
+       :param report_cmp: used as a reference,
+           every intermediate result is compared to every existing one;
+           if not empty, it is an instance of
+           :class:`onnx_diagnostic.reference.ReportResultComparison`
        :return: output tensors.
        """
        use_numpy = any(isinstance(t, np.ndarray) for t in feeds.values())
@@ -532,6 +538,21 @@ class TorchOnnxEvaluator:
                    f"+R {kernel.output[0]}: "
                    f"{self.runtime_info[kernel.output[0]].string_type()}"
                )
+           if report_cmp:
+               reported = report_cmp.report(
+                   dict(
+                       zip(
+                           kernel.output,
+                           (
+                               tuple((r.tensor if r else None) for r in res)  # type: ignore[attr-defined]
+                               if isinstance(res, tuple)
+                               else ((res.tensor if res else None),)  # type: ignore[attr-defined]
+                           ),
+                       )
+                   )
+               )
+               if self.verbose > 1:
+                   print(f" -- report {len(reported)} comparisons")

            # free intermediate results
            for name in self.last_used[it]:
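Wiring the two pieces together: pass a ReportResultComparison to run and read the collected report afterwards. A sketch assuming TorchOnnxEvaluator accepts a model path the way OnnxruntimeEvaluator does ("model.onnx" and the feeds are placeholders):

    import torch
    from onnx_diagnostic.reference import ReportResultComparison, TorchOnnxEvaluator

    ev = TorchOnnxEvaluator("model.onnx")
    report = ReportResultComparison({"expected": torch.randn(1, 8)})
    ev.run(None, {"input": torch.randn(1, 8)}, report_cmp=report)
    print(report.data)  # every matching intermediate result was compared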
onnx_diagnostic/tasks/automatic_speech_recognition.py
@@ -69,6 +69,9 @@ def get_inputs(
         use_cache:bool,return_dict:bool
     )
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"

onnx_diagnostic/tasks/feature_extraction.py
@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "sequence_length"
     shapes = {

onnx_diagnostic/tasks/fill_mask.py
@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "sequence_length"
     shapes = {

onnx_diagnostic/tasks/image_classification.py
@@ -48,11 +48,14 @@ def get_inputs(
     :param input_height: input height
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"
     assert isinstance(
-        input_width, int
+        input_height, int
     ), f"Unexpected type for input_height {type(input_height)}{config}"

     shapes = {
@@ -67,6 +70,9 @@ def get_inputs(
             -1, 1
         ),
     )
+    if model.__class__.__name__ == "ViTForImageClassification":
+        inputs["interpolate_pos_encoding"] = True
+        shapes["interpolate_pos_encoding"] = None  # type: ignore[assignment]
     res = dict(inputs=inputs, dynamic_shapes=shapes)
     if add_second_input:
         res["inputs2"] = get_inputs(

onnx_diagnostic/tasks/image_text_to_text.py
@@ -52,6 +52,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length"  # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/mixture_of_expert.py
@@ -61,6 +61,9 @@ def get_inputs(
     :param dynamic_rope: use dynamic rope (see :class:`transformers.LlamaConfig`)
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert not add_second_input, "add_second_input=True not yet implemented"
     raise NotImplementedError(f"get_inputs not yet implemented for task {__TASK__!r}.")

onnx_diagnostic/tasks/object_detection.py
@@ -41,6 +41,9 @@ def get_inputs(
     :param input_height: input height
     :return: dictionary
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     assert isinstance(
         input_width, int
     ), f"Unexpected type for input_width {type(input_width)}{config}"

onnx_diagnostic/tasks/sentence_similarity.py
@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"
     shapes = {

onnx_diagnostic/tasks/summarization.py
@@ -62,6 +62,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/text2text_generation.py
@@ -64,6 +64,9 @@ def get_inputs(
         decoder_input_ids:T7s1x1,
         encoder_outputs:dict(last_hidden_state:T1s1x16x512)
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("seq_length", min=1, max=4096)
     cache_length = "cache_length_key"  # torch.export.Dim("cache_length", min=1, max=4096)

onnx_diagnostic/tasks/text_classification.py
@@ -35,6 +35,9 @@ def get_inputs(
         token_type_ids:T7s1x13[0,0:A0.0],
         attention_mask:T7s1x13[1,1:A1.0])
     """
+    assert (
+        "cls_cache" not in kwargs
+    ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."
     batch = torch.export.Dim("batch", min=1, max=1024)
     seq_length = "seq_length"  # torch.export.Dim("sequence_length", min=1, max=1024)
     shapes = {
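Every task module received the same three-line guard; it fails fast when a caller asks for a cache class the task does not support yet. A minimal reproduction of its behaviour (get_inputs_stub is a hypothetical stand-in for any task's get_inputs):

    def get_inputs_stub(**kwargs):
        # same guard as added to every task's get_inputs
        assert (
            "cls_cache" not in kwargs
        ), f"Not yet implemented for cls_cache={kwargs['cls_cache']!r}."

    get_inputs_stub()  # passes
    try:
        get_inputs_stub(cls_cache="StaticCache")
    except AssertionError as e:
        print(e)  # Not yet implemented for cls_cache='StaticCache'.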